pax_global_header00006660000000000000000000000064142371372110014513gustar00rootroot0000000000000052 comment=048d50c8725305567469eafb0de3e07a82e65b59 tofu-0.12.0/000077500000000000000000000000001423713721100125505ustar00rootroot00000000000000tofu-0.12.0/.gitignore000066400000000000000000000000741423713721100145410ustar00rootroot00000000000000*.pyc build/ dist/ *.egg-info/ install_manifest*.txt .idea/ tofu-0.12.0/LICENSE000066400000000000000000000167431423713721100135700ustar00rootroot00000000000000 GNU LESSER GENERAL PUBLIC LICENSE Version 3, 29 June 2007 Copyright (C) 2007 Free Software Foundation, Inc. Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. This version of the GNU Lesser General Public License incorporates the terms and conditions of version 3 of the GNU General Public License, supplemented by the additional permissions listed below. 0. Additional Definitions. As used herein, "this License" refers to version 3 of the GNU Lesser General Public License, and the "GNU GPL" refers to version 3 of the GNU General Public License. "The Library" refers to a covered work governed by this License, other than an Application or a Combined Work as defined below. An "Application" is any work that makes use of an interface provided by the Library, but which is not otherwise based on the Library. Defining a subclass of a class defined by the Library is deemed a mode of using an interface provided by the Library. A "Combined Work" is a work produced by combining or linking an Application with the Library. The particular version of the Library with which the Combined Work was made is also called the "Linked Version". The "Minimal Corresponding Source" for a Combined Work means the Corresponding Source for the Combined Work, excluding any source code for portions of the Combined Work that, considered in isolation, are based on the Application, and not on the Linked Version. The "Corresponding Application Code" for a Combined Work means the object code and/or source code for the Application, including any data and utility programs needed for reproducing the Combined Work from the Application, but excluding the System Libraries of the Combined Work. 1. Exception to Section 3 of the GNU GPL. You may convey a covered work under sections 3 and 4 of this License without being bound by section 3 of the GNU GPL. 2. Conveying Modified Versions. If you modify a copy of the Library, and, in your modifications, a facility refers to a function or data to be supplied by an Application that uses the facility (other than as an argument passed when the facility is invoked), then you may convey a copy of the modified version: a) under this License, provided that you make a good faith effort to ensure that, in the event an Application does not supply the function or data, the facility still operates, and performs whatever part of its purpose remains meaningful, or b) under the GNU GPL, with none of the additional permissions of this License applicable to that copy. 3. Object Code Incorporating Material from Library Header Files. The object code form of an Application may incorporate material from a header file that is part of the Library. You may convey such object code under terms of your choice, provided that, if the incorporated material is not limited to numerical parameters, data structure layouts and accessors, or small macros, inline functions and templates (ten or fewer lines in length), you do both of the following: a) Give prominent notice with each copy of the object code that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the object code with a copy of the GNU GPL and this license document. 4. Combined Works. You may convey a Combined Work under terms of your choice that, taken together, effectively do not restrict modification of the portions of the Library contained in the Combined Work and reverse engineering for debugging such modifications, if you also do each of the following: a) Give prominent notice with each copy of the Combined Work that the Library is used in it and that the Library and its use are covered by this License. b) Accompany the Combined Work with a copy of the GNU GPL and this license document. c) For a Combined Work that displays copyright notices during execution, include the copyright notice for the Library among these notices, as well as a reference directing the user to the copies of the GNU GPL and this license document. d) Do one of the following: 0) Convey the Minimal Corresponding Source under the terms of this License, and the Corresponding Application Code in a form suitable for, and under terms that permit, the user to recombine or relink the Application with a modified version of the Linked Version to produce a modified Combined Work, in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source. 1) Use a suitable shared library mechanism for linking with the Library. A suitable mechanism is one that (a) uses at run time a copy of the Library already present on the user's computer system, and (b) will operate properly with a modified version of the Library that is interface-compatible with the Linked Version. e) Provide Installation Information, but only if you would otherwise be required to provide such information under section 6 of the GNU GPL, and only to the extent that such information is necessary to install and execute a modified version of the Combined Work produced by recombining or relinking the Application with a modified version of the Linked Version. (If you use option 4d0, the Installation Information must accompany the Minimal Corresponding Source and Corresponding Application Code. If you use option 4d1, you must provide the Installation Information in the manner specified by section 6 of the GNU GPL for conveying Corresponding Source.) 5. Combined Libraries. You may place library facilities that are a work based on the Library side by side in a single library together with other library facilities that are not Applications and are not covered by this License, and convey such a combined library under terms of your choice, if you do both of the following: a) Accompany the combined library with a copy of the same work based on the Library, uncombined with any other library facilities, conveyed under the terms of this License. b) Give prominent notice with the combined library that part of it is a work based on the Library, and explaining where to find the accompanying uncombined form of the same work. 6. Revised Versions of the GNU Lesser General Public License. The Free Software Foundation may publish revised and/or new versions of the GNU Lesser General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. Each version is given a distinguishing version number. If the Library as you received it specifies that a certain numbered version of the GNU Lesser General Public License "or any later version" applies to it, you have the option of following the terms and conditions either of that published version or of any later version published by the Free Software Foundation. If the Library as you received it does not specify a version number of the GNU Lesser General Public License, you may choose any version of the GNU Lesser General Public License ever published by the Free Software Foundation. If the Library as you received it specifies that a proxy can decide whether future versions of the GNU Lesser General Public License shall apply, that proxy's public statement of acceptance of any version is permanent authorization for you to choose that version for the Library. tofu-0.12.0/MANIFEST.in000066400000000000000000000000471423713721100143070ustar00rootroot00000000000000include pkgconfig.py include README.md tofu-0.12.0/README.md000066400000000000000000000065171423713721100140400ustar00rootroot00000000000000## About This repository contains Python data processing scripts to be used with the UFO framework. At the moment they are targeted at high-performance reconstruction of tomographic data sets. ## Installation Run pip install -r requirements-guis.txt # If you want to use flow or ez python setup.py install in a prepared virtualenv or as root for system-wide installation. Note, that if you do plan to use the graphical user interface you need PyQt4, pyqtgraph and PyOpenGL. You are strongly advised to install PyQt through your system package manager, you can install pyqtgraph and PyOpenGL using the pip package manager though: pip install pyqtgraph PyOpenGL ## Usage ### Flow `tofu flow` is a visual flow programming tool. You can create a flow by using any task from [ufo-filters](https://github.com/ufo-kit/ufo-filters) and execute it. In includes visualization of 2D and 3D results, so you can quickly check the output of your flow, which is useful for finding algorithm parameters. ![flow](https://user-images.githubusercontent.com/2648829/150096902-fdbf1b7e-b34e-4368-98ac-c924cad8a6cd.jpg) ### Reconstruction To do a tomographic reconstruction you simply call $ tofu tomo --sinograms $PATH_TO_SINOGRAMS from the command line. To get get correct results, you may need to append options such as `--axis-pos/-a` and `--angle-step/-a` (which are given in radians!). Input paths are either directories or glob patterns. Output paths are either directories or a format that contains one `%i` [specifier](http://www.pixelbeat.org/programming/gcc/format_specs.html): $ tofu tomo --axis-pos=123.4 --angle-step=0.000123 \ --sinograms="/foo/bar/*.tif" --output="/output/slices-%05i.tif" You can get a help for all options by running $ tofu tomo --help and more verbose output by running with the `-v/--verbose` flag. You can also load reconstruction parameters from a configuration file called `reco.conf`. You may create a template with $ tofu init Note, that options passed via the command line always override configuration parameters! Besides scripted reconstructions, one can also run a standalone GUI for both reconstruction and quick assessment of the reconstructed data via $ tofu gui ![GUI](https://cloud.githubusercontent.com/assets/115270/6442540/db0b55fe-c0f0-11e4-9577-0048fddae8b7.png) ### Performance measurement If you are running at least ufo-core/filters 0.6, you can evaluate the performance of the filtered backprojection (without sinogram transposition!), with $ tofu perf You can customize parameter scans, pretty easily via $ tofu perf --width 256:8192:256 --height 512 which will reconstruct all combinations of width between 256 and 8192 with a step of 256 and a fixed height of 512 pixels. ### Estimating the center of rotation If you do not know the correct center of rotation from your experimental setup, you can estimate it with: $ tofu estimate -i $PATH_TO_SINOGRAMS Currently, a modified algorithm based on the work of [Donath et al.](http://dx.doi.org/10.1364/JOSAA.23.001048) is used to determine the center. ## Citation If you use this software for publishing your data, we kindly ask to cite the article below. Faragó, T., Gasilov, S., Emslie, I., Zuber, M., Helfen, L., Vogelgesang, M. & Baumbach, T. (2022). J. Synchrotron Rad. 29, https://doi.org/10.1107/S160057752200282X tofu-0.12.0/bin/000077500000000000000000000000001423713721100133205ustar00rootroot00000000000000tofu-0.12.0/bin/tofu000077500000000000000000000172631423713721100142340ustar00rootroot00000000000000#!/usr/bin/env python3 import os import sys import argparse import logging import time import re import gi from tofu import config, __version__ gi.require_version('Ufo', '0.0') LOG = logging.getLogger('tofu') def init(args): if not os.path.exists(args.config): config.write(args.config) else: raise RuntimeError("{0} already exists".format(args.config)) def run_tomo(args): from tofu import reco reco.tomo(args) def run_lamino(args): from tofu import lamino lamino.lamino(args) def run_genreco(args): from tofu import genreco genreco.genreco(args) def run_flat_correct(args): from tofu import preprocess preprocess.run_flat_correct(args) def run_preprocessing(args): from tofu import preprocess preprocess.run_preprocessing(args) def run_sinos(args): from tofu import preprocess preprocess.run_sinogram_generation(args) def run_ez(args): from tofu.ez.GUI.ezufo_launcher import main_qt main_qt(args) def get_ipython_shell(config=None): import IPython version = IPython.__version__ shell = None def cmp_versions(v1, v2): """Compare two version numbers and return cmp compatible result""" def normalize(v): return [int(x) for x in re.sub(r'(\.0+)*$', '', v).split(".")] n1 = normalize(v1) n2 = normalize(v2) return (n1 > n2) - (n1 < n2) if cmp_versions(version, '0.11') < 0: from IPython.Shell import IPShellEmbed shell = IPShellEmbed() elif cmp_versions(version, '1.0') < 0: from IPython.frontend.terminal.embed import \ InteractiveShellEmbed shell = InteractiveShellEmbed(config=config, banner1='') else: from IPython.terminal.embed import InteractiveShellEmbed shell = InteractiveShellEmbed(config=config, banner1='') return shell def run_shell(args): from tofu import reco shell = get_ipython_shell() shell() def run_find_large_spots(args): from tofu.find_large_spots import find_large_spots find_large_spots(args) def gui(args): try: from tofu import gui gui.main(args) except ImportError as e: LOG.error(str(e)) def run_flow(args): from tofu.flow.main import main as flow_main flow_main() def estimate(params): from tofu import reco center = reco.estimate_center(params) if params.verbose: out = '>>> Best axis of rotation: {}'.format(center) else: out = center print(out) def perf(args): from tofu import reco def measure(args): exec_times = [] total_times = [] for i in range(args.num_runs): start = time.time() exec_times.append(reco.tomo(args)) total_times.append(time.time() - start) exec_time = sum(exec_times) / len(exec_times) total_time = sum(total_times) / len(total_times) overhead = (total_time / exec_time - 1.0) * 100 input_bandwidth = args.width * args.height * num_projections * 4 / exec_time / 1024. / 1024. output_bandwidth = args.width * args.width * height * 4 / exec_time / 1024. / 1024. slice_bandwidth = args.height / exec_time # Four bytes of our output bandwidth constitute one slice pixel, for each # pixel we have to do roughly n * 6 floating point ops (2 mad, 1 add, 1 # interpolation) flops = output_bandwidth / 4 * 6 * num_projections / 1024 msg = ("width={:<6d} height={:<6d} n_proj={:<6d} " "exec={:.4f}s total={:.4f}s overhead={:.2f}% " "bandwidth_i={:.2f}MB/s bandwidth_o={:.2f}MB/s slices={:.2f}/s " "flops={:.2f}GFLOPs\n") sys.stdout.write(msg.format(args.width, args.height, args.number, exec_time, total_time, overhead, input_bandwidth, output_bandwidth, slice_bandwidth, flops)) sys.stdout.flush() args.projections = None args.sinograms = None args.dry_run = True for width in range(*args.width_range): for height in range(*args.height_range): for num_projections in range(*args.num_projection_range): args.width = width args.height = height args.number = num_projections measure(args) def main(): parser = argparse.ArgumentParser() parser.add_argument('--config', **config.SECTIONS['general']['config']) parser.add_argument('--version', action='version', version='%(prog)s {}'.format(__version__)) sino_params = ('flat-correction', 'sinos') reco_params = ('flat-correction', 'reconstruction') tomo_params = config.TOMO_PARAMS lamino_params = config.LAMINO_PARAMS gui_params = tomo_params + ('gui', ) cmd_parsers = [ ('init', init, (), "Create configuration file"), ('preprocess', run_preprocessing, config.PREPROC_PARAMS, "Run preprocessing"), ('flatcorrect', run_flat_correct, ('flat-correction',), "Run flat field correction"), ('sinos', run_sinos, sino_params, "Generate sinograms from projections"), ('tomo', run_tomo, tomo_params, "Run tomographic reconstruction"), ('lamino', run_lamino, lamino_params, "Run laminographic reconstruction"), ('reco', run_genreco, config.GEN_RECO_PARAMS, "Run general projection-based " "reconstruction for tomographic/" "laminographic cone/parallel beam"), ('gui', gui, tomo_params + ('gui',), "GUI for tomographic reconstruction"), ('flow', run_flow, (), "Visual flow creation"), ('ez', run_ez, (), "GUI for making ufo-kit data processing pipelines"), ('estimate', estimate, tomo_params + ('estimate',), "Estimate center of rotation"), ('perf', perf, tomo_params + ('perf',), "Check reconstruction performance"), ('interactive', run_shell, tomo_params, "Run interactive mode"), ('find-large-spots', run_find_large_spots, ('find-large-spots',), "Find large spots on images"), ] if sys.version < '3.7': subparsers = parser.add_subparsers(title="Commands", dest='commands') else: subparsers = parser.add_subparsers(title="Commands", dest='commands', required=True) for cmd, func, sections, text in cmd_parsers: cmd_params = config.Params(sections=sections) cmd_parser = subparsers.add_parser(cmd, help=text, formatter_class=argparse.ArgumentDefaultsHelpFormatter) cmd_parser = cmd_params.add_arguments(cmd_parser) cmd_parser.set_defaults(_func=func) args = config.parse_known_args(parser, subparser=True) log_level = logging.DEBUG if args.verbose else logging.INFO LOG.setLevel(log_level) stream_handler = logging.StreamHandler(sys.stdout) stream_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s')) LOG.addHandler(stream_handler) if args.log: file_handler = logging.FileHandler(args.log) file_handler.setFormatter(logging.Formatter('[%(asctime)s] %(name)s:%(levelname)s: %(message)s')) LOG.addHandler(file_handler) try: config.log_values(args) args._func(args) except RuntimeError as e: LOG.error(str(e)) sys.exit(1) if __name__ == '__main__': main() # vim: ft=python tofu-0.12.0/docs/000077500000000000000000000000001423713721100135005ustar00rootroot00000000000000tofu-0.12.0/docs/Makefile000066400000000000000000000011101423713721100151310ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = source BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)tofu-0.12.0/docs/source/000077500000000000000000000000001423713721100150005ustar00rootroot00000000000000tofu-0.12.0/docs/source/api/000077500000000000000000000000001423713721100155515ustar00rootroot00000000000000tofu-0.12.0/docs/source/api/genreco.rst000066400000000000000000000001211423713721100177170ustar00rootroot000000000000003D Reconstruction ================= .. automodule:: tofu.genreco :members: tofu-0.12.0/docs/source/api/preprocessing.rst000066400000000000000000000001151423713721100211630ustar00rootroot00000000000000Pre-processing ============== .. automodule:: tofu.preprocess :members: tofu-0.12.0/docs/source/api/util.rst000066400000000000000000000000751423713721100172620ustar00rootroot00000000000000Utilities ========= .. automodule:: tofu.util :members: tofu-0.12.0/docs/source/conf.py000066400000000000000000000133151423713721100163020ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys sys.path.insert(0, os.path.abspath(os.path.join('..', '..'))) # -- Project information ----------------------------------------------------- project = 'Tofu' copyright = '2020, Tomas Farago' author = 'Tomas Farago' # The short X.Y version version = '' # The full version, including alpha/beta/rc tags release = '' # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.intersphinx', 'sphinx.ext.todo', 'sphinx.ext.coverage', 'sphinx.ext.imgmath', 'sphinx.ext.ifconfig', 'sphinx.ext.viewcode', 'sphinx.ext.githubpages', ] autodoc_mock_imports = ['gi'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = ['_static'] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'Tofudoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'Tofu.tex', 'Tofu Documentation', 'Tomas Farago', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'tofu', 'Tofu Documentation', [author], 1) ] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'Tofu', 'Tofu Documentation', author, 'Tofu', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] # -- Extension configuration ------------------------------------------------- # -- Options for intersphinx extension --------------------------------------- # Example configuration for intersphinx: refer to the Python standard library. intersphinx_mapping = {'https://docs.python.org/': None} # -- Options for todo extension ---------------------------------------------- # If true, `todo` and `todoList` produce output, else they produce nothing. todo_include_todos = True tofu-0.12.0/docs/source/index.rst000066400000000000000000000012071423713721100166410ustar00rootroot00000000000000.. Tofu documentation master file, created by sphinx-quickstart on Fri Aug 14 17:29:07 2020. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to Tofu's documentation! ================================ .. toctree:: :maxdepth: 2 :caption: Contents: Application Programming Interface ================================= .. toctree:: :maxdepth: 2 api/preprocessing api/genreco api/util Usage ===== .. toctree:: :maxdepth: 2 usage/genreco usage/flow Indices and tables ================== * :ref:`genindex` * :ref:`modindex` * :ref:`search` tofu-0.12.0/docs/source/usage/000077500000000000000000000000001423713721100161045ustar00rootroot00000000000000tofu-0.12.0/docs/source/usage/flow.rst000066400000000000000000000270231423713721100176110ustar00rootroot00000000000000Flow: Visual Graph Creation =========================== You can use command ``tofu flow`` to start a graphical user interface in which UFO tasks are represented as nodes which you can connect together. Once you have created your flow you can execute it. Nodes ----- An operation on data is represented by a node in a flow. A node has inputs and outputs, which have data types. An input or output of a node is represented by a `port`, which is a circle on the left of the node in case of an input and on the right in case of output. Every port has a data type which is represented by color. There are two data types: - `UFO`: you can connect all UFO nodes together - `Array`: a numpy array which comes out from UFO's ``memory_out`` node and may be used to visualize the processing result by ``image_viewer`` A node may have properties (almost all UFO nodes do, e.g. ``path`` in the ``read`` node) which are listed and can be set inside the node. If you hover the mouse over a property field, a tooltip will be shown describing that property. When you right click on a node which holds properties, a context menu pops up and let's you choose which properties you want to be visible and which not. Some nodes, like ``general_backproject`` have many properties, many of which may be considered `expert` options which are not needed most of the times. By hiding these properties, you can avoid clutter. There is a pre-defined set of properties, which are shown by default. When you create a node in the scene, this setting is applied and you can check which properties are hidden by default by clicking on a node right after its creation. In case a node doesn't have properties, right click either doesn't take effect or pops up a context menu relevant for that node. E.g. ``image_viewer``'s context menu allows you to configure the viewer's behavior. .. double click A node might implement an action on a double click, e.g. ``read`` node opens a dialog allowing you to choose the data ``path``, ``image_viewer`` pops up an external image window which can be enlarged and put on another display for convenience. Current nodes which implement double clicks are: - ``Composite``: opens a new window with a scene displaying internal composite nodes - ``read``: opens a dialog which allows you to choose the input ``path`` - ``write``: opens a dialog which allows you to choose the output ``filename`` - ``image_viewer``: opens the image in a new window .. auto fill ``read`` node currently supports an `auto fill` option, which may be invoked via the main menu bar. The node sets its ``number`` property to the number of detected images found in the specified ``path``. UFO Nodes ~~~~~~~~~ An UFO node represents an `UFO task` and holds properties which are the Properties of this `UFO task`. Please check the `UFO Filters Reference `_ for the complete list of UFO tasks and their properties. When you create an UFO node, its properties are the default properties of the encapsulated UFO task. Composite Nodes ~~~~~~~~~~~~~~~ In order to reduce clutter, you can combine several nodes in a composite node (main menu's `Nodes->Create Composite`) and you can also nest composites, i.e. have a composite node and create another composite node with the first one inside. Internal nodes are listed as groups in the composite node in the scene and similarly to property nodes, you can show and hide different internal nodes from the listing. The input and output ports of a composite node are the ports of its internal nodes which are not connected at the time of composite node's creation. A double click on a composite node opens its internal nodes in a separate window, where you can edit their properties but you can't add new nodes or change connections. You can open this window also by pressing `Nodes->Edit Composite` in the main menu. In order to store a composite node for later usage, you can export it into a file via the main menu's `Nodes->Export Composite`. You can import composite node definitions by `Nodes->Import Composites`, which are then available in the flow scene's context menu in the `Composite` category. There are several pre-defined composite nodes available via the scene's context menu (category `Composite`), they are: - ``CFlatFieldCorrect`` encapsulates readers and averagers and the ``flat_field_correct`` node itself - ``CPhaseRetrieve`` encapsulates padding, fourier tranformation and the phase retrieval itself General Backproject ~~~~~~~~~~~~~~~~~~~ This is a versatile back projection node which can reconstruct tomographic, laminographic, parallel and cone beam data. It has one parameter which is not part of the UFO task, ``slice-memory-coeff``. This parameter sets the fraction (0 - 1) of a graphic card's memory which will be used to store the reconstructed volume. If you are working with graphic cards which have other processes running on them and these processes use a lot of memory, then you might need to reduce this parameter. Phase Retrieval ~~~~~~~~~~~~~~~ ``retrieve_phase`` node may have varying number of inputs in order to support multi-distance phase retrieval. You specify the number of inputs in a dialog when you create the node. If you specify more than one input, the retrieval method will be the multidistance contrast transfer function and the ``method`` field will be fixed to `ctf_multidistance`. In this case, fields ``distance-x`` and ``distance-y`` will be disabled. If you specify one input, you may choose different methods via the ``method`` field. In this case, you can either specify one value in the ``distance`` field, or specify separate distances for `x` and `y` directions via ``distance-x`` and ``distance-y`` fields (they take precedence over ``distance`` field in case they are both non-zero). Image Viewer ~~~~~~~~~~~~ ``image_viewer`` lets you display the results of your flow. It is composed of the image itself and three text boxes with sliders, which allow you to specify the image index shown, the black point and white point. In case only one image is input the first slider is hidden. Right click on the node opens a context menu which allows you to reset the black and white points (`Reset`), set them automatically (`Auto Levels`) and specify whether they should be automatically adjusted when new images are on input or left unchanged (`Auto Levels on New Image`). Double click opens the image in a new window by using the PyQtGraph_ library. In case a separate window is open, image index, black and white point settings can be set eigher in the flow node or in the window and they are reflected in both the node and the window. Flows ----- On right click in the flow scene a context menu will pop up and you will be able to add nodes. Then you can connect them by dragging a node's output port into another node's input port if those ports have the same data type, which are distinguished by port colors. By connecting node ports you create your flow which you may later execute. Every node in the scene must have a unique `caption`, so when you create a ``read`` node, the caption will be ``Read``, when you create another ``read`` node, the caption will be ``Read 2`` and so on. This is important for setting property links explained below. The roots of the flow in the scene must be UFO nodes and leaves may have `UFO` or `Array` type. It is not possible to go from `UFO` to `Array` and back to `UFO`, i.e. the `UFO` portion of the flow in the scene must be one contiguous component of the flow. There may be only one flow in the scene and it must be completely connected (there can't be disconnected ports, e.g. ``flat_field_correct``'s ``darks`` port). You can delete the current flow by pressing `Flow->New`, you can save a flow into a flow file (.flow) by `Flow->Save` and open such files by `Flow->Open`. Property Links -------------- A property of a node might depend on another node's property, e.g. the number of dimensions of an ``ifft`` node depends on the number of dimensions of the predecessing ``fft`` task. In order to reduce the number of properties you need to set, you can `link` properties together, i.e. when you set one node's property, all the linked nodes' properties will be updated (e.g. when you change the number of dimensions of an ``ifft`` node, the number of dimensions of the linked ``fft`` node will be updated as well. You can create property links in the `Property Links` window (open via main menu bar's `View` field). At the top of the window, there is a tree view of the nodes in the scene. Its items are the nodes in the flow scene, and in case there are composites, they are listed recursively. The last level of the view are the properties of the nodes in the flow. You can drag these properties into the list in the second half of the window to start creating links. If you drag a property to a new row or a row doesn't exist yet, it is automatically added. If you drag a property into an existing row (over an existing cell), it is appended to this row and a link is created. Links are allowed only for properties with compatible data types, e.g. you cannot link ``read``'s ``path`` (a string) to ``fft``'s ``dimensions`` (a number). Also keep in mind that nodes which are able to process batches have their fields which are responsible for receiving different batches (e.g. ``number`` of the ``memory_out`` node) have string data type (so that you can type `{region}` inside) Execution --------- Execution of the flow starts with executing the UFO part of the flow, and if there is a ``memory_out`` and subsequent nodes, they get the result of the UFO processing as the batches are finished (or just one batch if no batch-capable nodes are in the flow). You start it by invoking main menu bar's `Flow->run` action. You can abort the execution but invoking `Flow->abort`. Batch Processing ~~~~~~~~~~~~~~~~ Some nodes require a lot of GPU memory and they can't process all the input data at once (e.g. ``general_backproject``). Based on your system, they can split the work on their own and tell the execution mechanism to run multiple batches. If your system has multiple GPUs, ``tofu flow`` may create several batches and each of these batches may be executed on one or more cards in your system. Currently, only *one* batch processing task is allowed in the flow and only ``general_backproject`` supports batch processing. In case your flow contains a node which is able to produce batches, then your consumer nodes must be able to process batches and they must be notified of the fact that they will get more batches on input. Currently, ``write`` and ``memory_out`` support batches and this is how you set them up for it: - ``write``: ``filename`` must contain `{region}` somewhere in it, e.g. `slices-{region}.tif` - ``memory_out``: ``number`` field must be set to `{region}` The `{region}` template is then replaced by the current batch identifier provided by the producer node which is capable of batch processing, e.g. `slices-0.tif`, `slices-100.tif` and so on. If there is no node capable of producing batches, this is how you set them up for normal, non-batch processing: - ``write``: ``filename`` field set to normal file name, e.g. `slices.tif` - ``memory_out``: ``number`` field set to the number of input images Python Console -------------- Main menu's `View->Python Console` opens up a Python interpreter console with attribute ``scene`` set to the flow scene, which allows you to interact with the nodes programatically, see `qtpynodeeditor docs `_ more details on flow scene functionality. .. _PyQtGraph: http://www.pyqtgraph.org/ .. _qtpynodeeditor: tofu-0.12.0/docs/source/usage/genreco.rst000066400000000000000000000033461423713721100202660ustar00rootroot00000000000000General 3D Reconstruction ========================= You can use command ``tofu reco`` to reconstruct paralell/cone beam tomography/laminography data. The algorithm is filtered back projection for parallel beam data and `Feldkamp `_ approach for cone beam data. It always reconstructs 2D slices in the plane parallel to the beam direction. The third dimensions may be the vertical slice position (the default) but can also be one of the geometrical parameters in order to find their best values for the final reconstruction (see ``tofu reco --help`` and check the ``--z-parameter`` entry for possible values). Angular input values are in degrees. To reconstruct slices -100, 100 with the step size 0.5 around the center which is defined as 1008.5 from 1500 projections acquired over 180 degrees stored in ``projs.tif``, with rotation axis in pixel 951 one would do:: tofu reco --projections projs.tif --number 1500 --center-position-x 951 --overall-angle 180 --center-position-z 1008.5 --region=-100,100,0.5 --output slices.tif To scan the roll angle around -2, 2 degrees with step 0.1 degree, one can use the following command:: tofu reco --projections projs.tif --number 1500 --overall-angle 180 --center-position-x 951 --center-position-z 1008.5 --z-parameter detector-angle-y --region=-2,2,0.1 --output detector-angle-y-scan.tif --disable-projection-crop To scan the rotation axis region from pixel 940 to pixel 960 with step 0.5 pixels, (the ``center-position-x`` parameter), one can use:: tofu reco --projections projs.tif --number 1500 --overall-angle 180 --center-position-z 1008.5 --z-parameter center-position-x --region=940,960,0.5 --output center-position-x-scan.tif tofu-0.12.0/requirements-flow-tests.txt000066400000000000000000000000211423713721100201320ustar00rootroot00000000000000pytest pytest-qt tofu-0.12.0/requirements-flow.txt000066400000000000000000000001061423713721100167760ustar00rootroot00000000000000PyGObject imageio numpy networkx PyQt5 pyqtconsole xdg qtpynodeeditor tofu-0.12.0/requirements-guis.txt000066400000000000000000000001061423713721100167760ustar00rootroot00000000000000PyGObject imageio numpy networkx PyQt5 pyqtconsole xdg qtpynodeeditor tofu-0.12.0/setup.py000066400000000000000000000016711423713721100142670ustar00rootroot00000000000000from setuptools import setup, find_packages from tofu import __version__ setup( name='ufo-tofu', python_requires='>=3', version=__version__, author='Matthias Vogelgesang', author_email='matthias.vogelgesang@kit.edu', url='http://github.com/ufo-kit/tofu', license='LGPL', packages=find_packages(), package_data={'tofu': ['gui.ui', 'ez/GUI/default_settings.yaml'], 'tofu.flow': ['composites/*.cm', 'config.json']}, scripts=['bin/tofu'], exclude_package_data={'': ['README.rst']}, install_requires= [ 'PyGObject', 'imageio', 'numpy', 'networkx', 'PyQt5', 'pyqtconsole', 'xdg', 'qtpynodeeditor' ], description="A fast, versatile and user-friendly image "\ "processing toolkit for computed tomography", long_description=open('README.md').read(), long_description_content_type='text/markdown', ) tofu-0.12.0/tofu/000077500000000000000000000000001423713721100135255ustar00rootroot00000000000000tofu-0.12.0/tofu/__init__.py000066400000000000000000000000271423713721100156350ustar00rootroot00000000000000__version__ = '0.12.0' tofu-0.12.0/tofu/config.py000066400000000000000000000705311423713721100153520ustar00rootroot00000000000000import argparse import sys import logging import configparser as configparser from collections import OrderedDict from tofu.util import convert_filesize, restrict_value, tupleize, range_list LOG = logging.getLogger(__name__) NAME = "reco.conf" SECTIONS = OrderedDict() SECTIONS['general'] = { 'config': { 'default': NAME, 'type': str, 'help': "File name of configuration", 'metavar': 'FILE'}, 'verbose': { 'default': False, 'help': 'Verbose output', 'action': 'store_true'}, 'output': { 'default': 'result-%05i.tif', 'type': str, 'help': "Path to location or format-specified file path " "for storing reconstructed slices", 'metavar': 'PATH'}, 'output-bitdepth': { 'default': 32, 'type': restrict_value((0, None), dtype=int), 'help': "Bit depth of output, either 8, 16 or 32", 'metavar': 'BITDEPTH'}, 'output-minimum': { 'default': None, 'type': float, 'help': "Minimum value that maps to zero", 'metavar': 'MIN'}, 'output-maximum': { 'default': None, 'type': float, 'help': "Maximum input value that maps to largest output value", 'metavar': 'MAX'}, 'output-bytes-per-file': { 'default': '128g', 'type': convert_filesize, 'help': "Maximum bytes per file (0=single-image output, otherwise multi-image output)\ , 'k', 'm', 'g', 't' suffixes can be used", 'metavar': 'BYTESPERFILE'}, 'output-append': { 'default': False, 'action': 'store_true', 'help': 'Append images instead of overwriting existing files'}, 'log': { 'default': None, 'type': str, 'help': "File name of optional log", 'metavar': 'FILE'}, 'width': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Input width"}} SECTIONS['reading'] = { 'y': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'help': 'Vertical coordinate from where to start reading the input image'}, 'height': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of rows which will be read"}, 'bitdepth': { 'default': 32, 'type': restrict_value((0, None), dtype=int), 'help': "Bit depth of raw files"}, 'y-step': { 'type': restrict_value((0, None), dtype=int), 'default': 1, 'help': "Read every \"step\" row from the input"}, 'start': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'help': 'Offset to the first read file'}, 'number': { 'type': restrict_value((0, None), dtype=int), 'default': None, 'help': 'Number of files to read'}, 'step': { 'type': restrict_value((0, None), dtype=int), 'default': 1, 'help': 'Read every \"step\" file'}, 'resize': { 'type': restrict_value((0, None), dtype=int), 'default': None, 'help': 'Bin pixels before processing'}, 'retries': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'metavar': 'NUMBER', 'help': 'How many times to wait for new files'}, 'retry-timeout': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'metavar': 'TIME', 'help': 'How long to wait for new files per trial'}} SECTIONS['flat-correction'] = { 'projections': { 'default': None, 'type': str, 'help': "Location with projections", 'metavar': 'PATH'}, 'darks': { 'default': None, 'type': str, 'help': "Location with darks", 'metavar': 'PATH'}, 'dark-scale': { 'default': 1, 'type': float, 'help': "Scaling dark"}, 'reduction-mode': { 'default': "Average", 'type': str, 'help': "Flat-field correction options: Average (darks) or median (flats)"}, 'fix-nan-and-inf': { 'default': False, 'help': "Fix nan and inf", 'action': 'store_true'}, 'flats': { 'default': None, 'type': str, 'help': "Location with flats", 'metavar': 'PATH'}, 'flats2': { 'default': None, 'type': str, 'help': "Location with flats 2 for interpolation correction", 'metavar': 'PATH'}, 'flat-scale': { 'default': 1, 'type': float, 'help': "Scaling flat"}, 'absorptivity': { 'default': False, 'action': 'store_true', 'help': 'Do absorption correction'}} SECTIONS['retrieve-phase'] = { 'retrieval-method': { 'choices': ['tie', 'ctf', 'qp', 'qp2'], 'default': 'tie', 'help': "Phase retrieval method"}, 'energy': { 'default': None, 'type': float, 'help': "X-ray energy [keV]"}, 'propagation-distance': { 'default': None, 'type': tupleize(), 'help': ("Sample <-> detector distance (if one value, then use the same for x and y" "direction, otherwise first specifies x and second y direction) [m]")}, 'pixel-size': { 'default': 1e-6, 'type': float, 'help': "Pixel size [m]"}, 'regularization-rate': { 'default': 2, 'type': float, 'help': "Regularization rate (typical values between [2, 3])"}, 'delta': { 'default': None, 'type': float, 'help': "Real part of the complex refractive index of the material. " "If specified, phase retrieval returns projected thickness, " "if not, it returns phase"}, 'retrieval-padded-width': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded width used for phase retrieval"}, 'retrieval-padded-height': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "Padded height used for phase retrieval"}, 'retrieval-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}, 'thresholding-rate': { 'default': 0.01, 'type': float, 'help': "Thresholding rate (typical values between [0.01, 0.1])"}, 'frequency-cutoff': { 'default': 1e30, 'type': float, 'help': "Phase retrieval frequency cutoff [rad]"}} SECTIONS['sinos'] = { 'pass-size': { 'type': restrict_value((0, None), dtype=int), 'default': 0, 'help': 'Number of sinograms to process per pass'}} SECTIONS['reconstruction'] = { 'sinograms': { 'default': None, 'type': str, 'help': "Location with sinograms", 'metavar': 'PATH'}, 'angle': { 'default': None, 'type': float, 'help': "Angle step between projections in radians"}, 'enable-tracing': { 'default': False, 'help': "Enable tracing and store result in .PID.json", 'action': 'store_true'}, 'remotes': { 'default': None, 'type': str, 'help': "Addresses to remote ufo-nodes", 'nargs': '+'}, 'projection-filter': { 'default': 'ramp-fromreal', 'type': str, 'help': "Projection filter", 'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']}, 'projection-filter-cutoff': { 'default': 0.5, 'type': float, 'help': "Relative cutoff frequency"}, 'projection-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}} SECTIONS['tomographic-reconstruction'] = { 'axis': { 'default': None, 'type': float, 'help': "Axis position"}, 'dry-run': { 'default': False, 'help': "Reconstruct without writing data", 'action': 'store_true'}, 'offset': { 'default': 0.0, 'type': float, 'help': "Angle offset of first projection in radians"}, 'method': { 'default': 'fbp', 'type': str, 'help': "Reconstruction method", 'choices': ['fbp', 'dfi', 'sart', 'sirt', 'sbtv', 'asdpocs']}} SECTIONS['laminographic-reconstruction'] = { 'angle': { 'default': None, 'type': float, 'help': "Angle step between projections in radians"}, 'dry-run': { 'default': False, 'help': "Reconstruct without writing data", 'action': 'store_true'}, 'axis': { 'default': None, 'required': True, 'type': tupleize(num_items=2), 'help': "Axis position"}, 'x-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3, conv=int), 'help': "X region as from,to,step"}, 'y-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3, conv=int), 'help': "Y region as from,to,step"}, 'z': { 'default': 0, 'type': int, 'help': "Z coordinate of the reconstructed slice"}, 'z-parameter': { 'default': 'z', 'type': str, 'choices': ['z', 'x-center', 'lamino-angle', 'roll-angle'], 'help': "Parameter to vary along the reconstructed z-axis"}, 'region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "Z-axis parameter region as from,to,step"}, 'overall-angle': { 'default': None, 'type': float, 'help': "The total angle over which projections were taken in degrees"}, 'lamino-angle': { 'default': None, 'required': True, 'type': float, 'help': "The laminographic angle in degrees"}, 'roll-angle': { 'default': 0.0, 'type': float, 'help': "Sample angular misalignment to the side (roll) in degrees, positive angles mean\ clockwise misalignment"}, 'slices-per-device': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of slices computed by one computing device"}, 'only-bp': { 'default': False, 'action': 'store_true', 'help': "Do only backprojection with no other processing steps"}, 'lamino-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp', 'help': "Padded values assignment for the filtered projection"}} SECTIONS['fbp'] = { 'crop-width': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Width of final slice"}, 'projection-crop-after': { 'choices': ['filter', 'backprojection'], 'default': 'backprojection', 'help': "Whether to crop projections after filtering (can cause truncation " "artifacts) or after backprojection"}} SECTIONS['dfi'] = { 'oversampling': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Oversample factor"}} SECTIONS['ir'] = { 'num-iterations': { 'default': 10, 'type': restrict_value((0, None), dtype=int), 'help': "Maximum number of iterations"}} SECTIONS['sart'] = { 'relaxation-factor': { 'default': 0.25, 'type': float, 'help': "Relaxation factor"}} SECTIONS['sbtv'] = { 'lambda': { 'default': 0.1, 'type': float, 'help': "Lambda"}, 'mu': { 'default': 0.5, 'type': float, 'help': "mu"}} SECTIONS['gui'] = { 'enable-cropping': { 'default': False, 'help': "Enable cropping width", 'action': 'store_true'}, 'show-2d': { 'default': False, 'help': "Show 2D slices with pyqtgraph", 'action': 'store_true'}, 'show-3d': { 'default': False, 'help': "Show 3D slices with pyqtgraph", 'action': 'store_true'}, 'last-dir': { 'default': '.', 'type': str, 'help': "Path of the last used directory", 'metavar': 'PATH'}, 'deg0': { 'default': '.', 'type': str, 'help': "Location with 0 deg projection", 'metavar': 'PATH'}, 'deg180': { 'default': '.', 'type': str, 'help': "Location with 180 deg projection", 'metavar': 'PATH'}, 'ffc-correction': { 'default': False, 'help': "Enable darks or flats correction", 'action': 'store_true'}, 'num-flats': { 'default': 0, 'type': int, 'help': "Number of flats for ffc correction."}} SECTIONS['estimate'] = { 'estimate-method': { 'type': str, 'default': 'correlation', 'help': 'Rotation axis estimation algorithm', 'choices': ['reconstruction', 'correlation']}} SECTIONS['perf'] = { 'num-runs': { 'default': 3, 'type': restrict_value((0, None), dtype=int), 'help': "Number of runs"}, 'width-range': { 'default': '1024', 'type': range_list, 'help': "Width or range of widths of generated projections"}, 'height-range': { 'default': '1024', 'type': range_list, 'help': "Height or range of heights of generated projections"}, 'num-projection-range': { 'default': '512', 'type': range_list, 'help': "Number or range of number of projections"}} SECTIONS['preprocess'] = { 'transpose-input': { 'default': False, 'action': 'store_true', 'help': "Transpose projections before they are backprojected (after phase retrieval)"}, 'projection-filter': { 'default': 'ramp-fromreal', 'type': str, 'help': "Projection filter", 'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']}, 'projection-filter-cutoff': { 'default': 0.5, 'type': float, 'help': "Relative cutoff frequency"}, 'projection-filter-scale': { 'default': 1., 'type': float, 'help': "Multiplicative factor of the projection filter"}, 'projection-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp_to_edge', 'help': "Padded values assignment"}, 'projection-crop-after': { 'choices': ['filter', 'backprojection'], 'default': 'backprojection', 'help': "Whether to crop projections after filtering (can cause truncation " "artifacts) or after backprojection"}} SECTIONS['cone-beam-weight'] = { 'source-position-y': { 'default': "-Inf", 'type': tupleize(dtype=list), 'help': "Y source position (along beam direction) in global coordinates [pixels]"}, 'detector-position-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Y detector position (along beam direction) in global coordinates [pixels]"}, 'center-position-x': { 'default': None, 'type': tupleize(), 'help': "X rotation axis position on a projection"}, 'center-position-z': { 'default': None, 'type': tupleize(), 'help': "Z rotation axis position on a projection"}, 'axis-angle-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the x axis" "(laminographic angle, 0 = tomography) [deg]"}} SECTIONS['general-reconstruction'] = { 'enable-tracing': { 'default': False, 'help': "Enable tracing and store result in .PID.json", 'action': 'store_true'}, 'disable-cone-beam-weight': { 'default': False, 'action': 'store_true', 'help': "Disable cone beam weighting"}, 'slice-memory-coeff': { 'default': 0.8, 'type': restrict_value((0.01, 0.95)), 'help': "Portion of the GPU memory used for slices (from 0.01 to 0.9) [fraction]. " "The total amount of consumed memory will be larger depending on the " "complexity of the graph. In case of OpenCL memory allocation errors, " "try reducing this value."}, 'num-gpu-threads': { 'default': 1, 'type': restrict_value((1, None), dtype=int), 'help': "Number of parallel reconstruction threads on one GPU"}, 'disable-projection-crop': { 'default': False, 'action': 'store_true', 'help': "Disable automatic cropping of projections computed from volume region"}, 'dry-run': { 'default': False, 'help': "Reconstruct without reading or writing data", 'action': 'store_true'}, 'data-splitting-policy': { 'default': 'one', 'type': str, 'help': "'one': one GPU should process as many slices as possible, " "'many': slices should be spread across as many GPUs as possible", 'choices': ['one', 'many']}, 'projection-margin': { 'default': 0, 'type': restrict_value((0, None), dtype=int), 'help': "By optimization of the read projection region, the read region will be " "[y - margin, y + height + margin]"}, 'slices-per-device': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of slices computed by one computing device"}, 'gpus': { 'default': None, 'nargs': '+', 'type': int, 'help': "GPUs with these indices will be used (0-based)"}, 'burst': { 'default': None, 'type': restrict_value((0, None), dtype=int), 'help': "Number of projections processed per kernel invocation"}, 'x-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "x region as from,to,step"}, 'y-region': { 'default': "0,-1,1", 'type': tupleize(num_items=3), 'help': "y region as from,to,step"}, 'z': { 'default': 0, 'type': int, 'help': "z coordinate of the reconstructed slice"}, 'z-parameter': { 'default': 'z', 'type': str, 'choices': ['axis-angle-x', 'axis-angle-y', 'axis-angle-z', 'volume-angle-x', 'volume-angle-y', 'volume-angle-z', 'detector-angle-x', 'detector-angle-y', 'detector-angle-z', 'detector-position-x', 'detector-position-y', 'detector-position-z', 'source-position-x', 'source-position-y', 'source-position-z', 'center-position-x', 'center-position-z', 'z'], 'help': "Parameter to vary along the reconstructed z-axis"}, 'region': { 'default': "0,1,1", 'type': tupleize(num_items=3), 'help': "z axis parameter region as from,to,step"}, 'source-position-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "X source position (horizontal) in global coordinates [pixels]"}, 'source-position-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Z source position (vertical) in global coordinates [pixels]"}, 'detector-position-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "X detector position (horizontal) in global coordinates [pixels]"}, 'detector-position-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Z detector position (vertical) in global coordinates [pixels]"}, 'detector-angle-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the x axis (horizontal) [deg]"}, 'detector-angle-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the y axis (along beam direction) [deg]"}, 'detector-angle-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Detector rotation around the z axis (vertical) [deg]"}, 'axis-angle-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the y axis (along beam direction) [deg]"}, 'axis-angle-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Rotation axis rotation around the z axis (vertical) [deg]"}, 'volume-angle-x': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the x axis (horizontal) [deg]"}, 'volume-angle-y': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the y axis (along beam direction) [deg]"}, 'volume-angle-z': { 'default': "0", 'type': tupleize(dtype=list), 'help': "Volume rotation around the z axis (vertical) [deg]"}, 'compute-type': { 'default': 'float', 'type': str, 'help': "Data type for performing kernel math operations", 'choices': ['half', 'float', 'double']}, 'result-type': { 'default': 'float', 'type': str, 'help': "Data type for storing the intermediate gray value for a voxel " "from various rotation angles", 'choices': ['half', 'float', 'double']}, 'store-type': { 'default': 'float', 'type': str, 'help': "Data type of the output volume", 'choices': ['half', 'float', 'double', 'uchar', 'ushort', 'uint']}, 'overall-angle': { 'default': None, 'type': float, 'help': "The total angle over which projections were taken in degrees"}, 'genreco-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'clamp', 'help': "Padded values assignment for the filtered projection"}, 'slice-gray-map': { 'default': "0,0", 'type': tupleize(num_items=2, conv=float), 'help': "Minimum and maximum gray value mapping if store-type is integer-based"}} SECTIONS['find-large-spots'] = { 'images': { 'default': None, 'type': str, 'help': "Location with input images", 'metavar': 'PATH'}, 'gauss-sigma': { 'default': 0.0, 'type': float, 'help': "Gaussian sigma for removing low frequencies (filter will be 1 - gauss window)"}, 'blurred-output': { 'default': None, 'type': str, 'help': "Path where to store the blurred input"}, 'spot-threshold': { 'default': 0.0, 'type': float, 'help': "Pixels with grey value larger than this are considered as spots"}, 'spot-threshold-mode': { 'default': 'absolute', 'type': str, 'help': "Pixels must be either \"below\", \"above\" the spot threshold, or \ their \"absolute\" value can be compared", 'choices': ['below', 'above', 'absolute']}, 'grow-threshold': { 'default': 0.0, 'type': float, 'help': "Spot growing threshold, if 0 it will be set to FWTM times noise standard deviation"}, 'find-large-spots-padding-mode': { 'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'], 'default': 'repeat', 'help': "Padded values assignment for the filtered input image"}, } TOMO_PARAMS = ('flat-correction', 'reconstruction', 'tomographic-reconstruction', 'fbp', 'dfi', 'ir', 'sart', 'sbtv') PREPROC_PARAMS = ('preprocess', 'cone-beam-weight', 'flat-correction', 'retrieve-phase') LAMINO_PARAMS = PREPROC_PARAMS + ('laminographic-reconstruction',) GEN_RECO_PARAMS = PREPROC_PARAMS + ('general-reconstruction',) NICE_NAMES = ('General', 'Input', 'Flat field correction', 'Phase retrieval', 'Sinogram generation', 'General reconstruction', 'Tomographic reconstruction', 'Laminographic reconstruction', 'Filtered backprojection', 'Direct Fourier Inversion', 'Iterative reconstruction', 'SART', 'SBTV', 'GUI settings', 'Estimation', 'Performance', 'Preprocess', 'Cone beam weight', 'General reconstruction', 'Find large spots') def get_config_name(): """Get the command line --config option.""" name = '' for i, arg in enumerate(sys.argv): if arg.startswith('--config'): if arg == '--config': return sys.argv[i + 1] else: name = sys.argv[i].split('--config')[1] if name[0] == '=': name = name[1:] return name return name def parse_known_args(parser, subparser=False): """ Parse arguments from file and then override by the ones specified on the command line. Use *parser* for parsing and is *subparser* is True take into account that there is a value on the command line specifying the subparser. """ if len(sys.argv) > 1: subparser_value = [sys.argv[1]] if subparser else [] config_values = config_to_list(config_name=get_config_name()) values = subparser_value + config_values + sys.argv[1:] args = None if config_values: args = parser.parse_known_args(args=subparser_value + config_values)[0] parser.parse_args(args=sys.argv[1:], namespace=args) else: values = "" return parser.parse_known_args(values)[0] def config_to_list(config_name=''): """ Read arguments from config file and convert them to a list of keys and values as sys.argv does when they are specified on the command line. *config_name* is the file name of the config file. """ result = [] config = configparser.ConfigParser() if not config.read([config_name]): return [] for section in SECTIONS: for name, opts in ((n, o) for n, o in list(SECTIONS[section].items()) if config.has_option(section, n)): value = config.get(section, name) if value != '' and value != 'None': action = opts.get('action', None) if action == 'store_true' and value == 'True': # Only the key is on the command line for this action result.append('--{}'.format(name)) if not action == 'store_true': if opts.get('nargs', None) == '+': result.append('--{}'.format(name)) result.extend((v.strip() for v in value.split(','))) else: result.append('--{}={}'.format(name, value)) return result class Params(object): def __init__(self, sections=()): self.sections = sections + ('general', 'reading') def add_parser_args(self, parser): for section in self.sections: for name in sorted(SECTIONS[section]): opts = SECTIONS[section][name] parser.add_argument('--{}'.format(name), **opts) def add_arguments(self, parser): self.add_parser_args(parser) return parser def get_defaults(self): parser = argparse.ArgumentParser() self.add_arguments(parser) return parser.parse_args('') def write(config_file, args=None, sections=None): """ Write *config_file* with values from *args* if they are specified, otherwise use the defaults. If *sections* are specified, write values from *args* only to those sections, use the defaults on the remaining ones. """ config = configparser.ConfigParser() for section in SECTIONS: config.add_section(section) for name, opts in list(SECTIONS[section].items()): if args and sections and section in sections and hasattr(args, name.replace('-', '_')): value = getattr(args, name.replace('-', '_')) if isinstance(value, list): value = ', '.join(value) else: value = opts['default'] if opts['default'] != None else '' prefix = '# ' if value == '' else '' if name != 'config': config.set(section, prefix + name, value) with open(config_file, 'wb') as f: config.write(f) def log_values(args): """Log all values set in the args namespace. Arguments are grouped according to their section and logged alphabetically using the DEBUG log level thus --verbose is required. """ args = args.__dict__ for section, name in zip(SECTIONS, NICE_NAMES): entries = sorted((k for k in list(args.keys()) if k.replace('_', '-') in SECTIONS[section])) if entries: LOG.debug(name) for entry in entries: value = args[entry] if args[entry] is not None else "-" LOG.debug(" {:<16} {}".format(entry, value)) tofu-0.12.0/tofu/ez/000077500000000000000000000000001423713721100141435ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/000077500000000000000000000000001423713721100145675ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Advanced/000077500000000000000000000000001423713721100162745ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Advanced/__init__.py000066400000000000000000000000001423713721100203730ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Advanced/advanced.py000066400000000000000000000141201423713721100204110ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class AdvancedGroup(QGroupBox): """ Advanced Tofu Reco settings """ def __init__(self): super().__init__() self.setTitle("Advanced TOFU Reconstruction Settings") self.setStyleSheet("QGroupBox {color: green;}") # LAMINO self.lamino_group = QGroupBox("Extended Settings of Reconstruction Algorithms") self.lamino_group.clicked.connect(self.set_lamino_group) self.lamino_angle_label = QLabel("Laminographic angle ") self.lamino_angle_entry = QLineEdit() self.lamino_angle_entry.editingFinished.connect(self.set_lamino_angle) self.overall_rotation_label = QLabel("Overall rotation range about CT Z-axis") self.overall_rotation_entry = QLineEdit() self.overall_rotation_entry.editingFinished.connect(self.set_overall_rotation) self.center_position_z_label = QLabel("Center Position Z ") self.center_position_z_entry = QLineEdit() self.center_position_z_entry.editingFinished.connect(self.set_center_position_z) self.axis_rotation_y_label = QLabel( "Sample rotation about the beam Y-axis " ) self.axis_rotation_y_entry = QLineEdit() self.axis_rotation_y_entry.editingFinished.connect(self.set_rotation_about_beam) # AUXILIARY FFC self.dark_scale_label = QLabel("Dark scale ") self.dark_scale_entry = QLineEdit() self.dark_scale_entry.editingFinished.connect(self.set_dark_scale) self.flat_scale_label = QLabel("Flat scale ") self.flat_scale_entry = QLineEdit() self.flat_scale_entry.editingFinished.connect(self.set_flat_scale) self.set_layout() def set_layout(self): layout = QGridLayout() self.lamino_group.setCheckable(True) self.lamino_group.setChecked(False) lamino_layout = QGridLayout() lamino_layout.addWidget(self.lamino_angle_label, 0, 0) lamino_layout.addWidget(self.lamino_angle_entry, 0, 1) lamino_layout.addWidget(self.overall_rotation_label, 1, 0) lamino_layout.addWidget(self.overall_rotation_entry, 1, 1) lamino_layout.addWidget(self.center_position_z_label, 2, 0) lamino_layout.addWidget(self.center_position_z_entry, 2, 1) lamino_layout.addWidget(self.axis_rotation_y_label, 3, 0) lamino_layout.addWidget(self.axis_rotation_y_entry, 3, 1) self.lamino_group.setLayout(lamino_layout) aux_group = QGroupBox("Auxiliary FFC Settings") aux_group.setCheckable(True) aux_group.setChecked(False) aux_layout = QGridLayout() aux_layout.addWidget(self.dark_scale_label, 0, 0) aux_layout.addWidget(self.dark_scale_entry, 0, 1) aux_layout.addWidget(self.flat_scale_label, 1, 0) aux_layout.addWidget(self.flat_scale_entry, 1, 1) aux_group.setLayout(aux_layout) layout.addWidget(self.lamino_group) layout.addWidget(aux_group) self.setLayout(layout) def init_values(self): self.lamino_group.setChecked(False) parameters.params['advanced_advtofu_extended_settings'] = False self.lamino_angle_entry.setText("30") parameters.params['advanced_advtofu_lamino_angle'] = 30 self.overall_rotation_entry.setText("360") parameters.params['advanced_adv_tofu_z_axis_rotation'] = 360 self.center_position_z_entry.setText("") parameters.params['advanced_advtofu_center_position_z'] = "" self.axis_rotation_y_entry.setText("") parameters.params['advanced_advtofu_y_axis_rotation'] = "" self.dark_scale_entry.setText("") parameters.params['advanced_advtofu_aux_ffc_dark_scale'] = "" self.flat_scale_entry.setText("") parameters.params['advanced_advtofu_aux_ffc_flat_scale'] = "" def set_values_from_params(self): self.lamino_group.setChecked(parameters.params['advanced_advtofu_extended_settings']) self.lamino_angle_entry.setText(str(parameters.params['advanced_advtofu_lamino_angle'])) self.overall_rotation_entry.setText(str(parameters.params['advanced_adv_tofu_z_axis_rotation'])) self.center_position_z_entry.setText(str(parameters.params['advanced_advtofu_center_position_z'])) self.axis_rotation_y_entry.setText(str(parameters.params['advanced_advtofu_y_axis_rotation'])) self.dark_scale_entry.setText(str(parameters.params['advanced_advtofu_aux_ffc_dark_scale'])) self.flat_scale_entry.setText(str(parameters.params['advanced_advtofu_aux_ffc_flat_scale'])) def set_lamino_group(self): LOG.debug("Lamino: " + str(self.lamino_group.isChecked())) parameters.params['advanced_advtofu_extended_settings'] = bool(self.lamino_group.isChecked()) def set_lamino_angle(self): LOG.debug(self.lamino_angle_entry.text()) parameters.params['advanced_advtofu_lamino_angle'] = str(self.lamino_angle_entry.text()) def set_overall_rotation(self): LOG.debug(self.overall_rotation_entry.text()) parameters.params['advanced_adv_tofu_z_axis_rotation'] = str(self.overall_rotation_entry.text()) def set_center_position_z(self): LOG.debug(self.center_position_z_entry.text()) parameters.params['advanced_advtofu_center_position_z'] = str(self.center_position_z_entry.text()) def set_rotation_about_beam(self): LOG.debug(self.axis_rotation_y_entry.text()) parameters.params['advanced_advtofu_y_axis_rotation'] = str(self.axis_rotation_y_entry.text()) def set_dark_scale(self): LOG.debug(self.dark_scale_entry.text()) parameters.params['advanced_advtofu_aux_ffc_dark_scale'] = str(self.dark_scale_entry.text()) def set_flat_scale(self): LOG.debug(self.flat_scale_entry.text()) parameters.params['advanced_advtofu_aux_ffc_flat_scale'] = str(self.flat_scale_entry.text()) tofu-0.12.0/tofu/ez/GUI/Advanced/ffc.py000066400000000000000000000123761423713721100174150ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import ( QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox, QRadioButton, QHBoxLayout, ) import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class FFCGroup(QGroupBox): """ Flat Field Correction Settings """ def __init__(self): super().__init__() self.setTitle("Flat Field Correction") self.setStyleSheet("QGroupBox {color: indigo;}") self.method_label = QLabel("Method:") self.average_rButton = QRadioButton("Average") self.average_rButton.clicked.connect(self.set_method) self.ssim_rButton = QRadioButton("SSIM") self.ssim_rButton.clicked.connect(self.set_method) self.eigen_rButton = QRadioButton("Eigen") self.eigen_rButton.clicked.connect(self.set_method) self.enable_sinFFC_checkbox = QCheckBox( "Use Smart Intensity Normalization Flat Field Correction" ) self.enable_sinFFC_checkbox.stateChanged.connect(self.set_sinFFC) self.eigen_pco_repetitions_label = QLabel("Eigen PCO Repetitions") self.eigen_pco_repetitions_entry = QLineEdit() self.eigen_pco_repetitions_entry.editingFinished.connect(self.set_pcoReps) self.eigen_pco_downsample_label = QLabel("Eigen PCO Downsample") self.eigen_pco_downsample_entry = QLineEdit() self.eigen_pco_downsample_entry.editingFinished.connect(self.set_pcoDowns) self.downsample_label = QLabel("Downsample") self.downsample_entry = QLineEdit() self.downsample_entry.editingFinished.connect(self.set_downsample) self.set_layout() def set_layout(self): layout = QGridLayout() rbutton_layout = QHBoxLayout() rbutton_layout.addWidget(self.method_label) rbutton_layout.addWidget(self.eigen_rButton) rbutton_layout.addWidget(self.average_rButton) rbutton_layout.addWidget(self.ssim_rButton) layout.addWidget(self.enable_sinFFC_checkbox, 0, 0) layout.addItem(rbutton_layout, 1, 0, 1, 2) layout.addWidget(self.eigen_pco_repetitions_label, 2, 0) layout.addWidget(self.eigen_pco_repetitions_entry, 2, 1) layout.addWidget(self.eigen_pco_downsample_label, 3, 0) layout.addWidget(self.eigen_pco_downsample_entry, 3, 1) layout.addWidget(self.downsample_label, 4, 0) layout.addWidget(self.downsample_entry, 4, 1) self.setLayout(layout) def init_values(self): self.eigen_rButton.setChecked(True) self.average_rButton.setChecked(False) self.ssim_rButton.setChecked(False) parameters.params['advanced_ffc_method'] = "eigen" self.enable_sinFFC_checkbox.setChecked(False) self.eigen_pco_repetitions_entry.setText("4") self.eigen_pco_downsample_entry.setText("2") self.downsample_entry.setText("4") def set_values_from_params(self): self.enable_sinFFC_checkbox.setChecked(parameters.params['advanced_ffc_sinFFC']) self.set_method_from_params() self.eigen_pco_repetitions_entry.setText(str(parameters.params['advanced_ffc_eigen_pco_reps'])) self.eigen_pco_downsample_entry.setText(str(parameters.params['advanced_ffc_eigen_pco_downsample'])) self.downsample_entry.setText(str(parameters.params['advanced_ffc_downsample'])) def set_sinFFC(self): logging.debug("sinFFC: " + str(self.enable_sinFFC_checkbox.isChecked())) parameters.params['advanced_ffc_sinFFC'] = bool(self.enable_sinFFC_checkbox.isChecked()) def set_pcoReps(self): logging.debug("PCO Reps: " + str(self.eigen_pco_repetitions_entry.text())) parameters.params['advanced_ffc_eigen_pco_reps'] = str(self.eigen_pco_repetitions_entry.text()) def set_pcoDowns(self): logging.debug("PCO Downsample: " + str(self.eigen_pco_downsample_entry.text())) parameters.params['advanced_ffc_eigen_pco_downsample'] = str(self.eigen_pco_downsample_entry.text()) def set_downsample(self): logging.debug("Downsample: " + str(self.downsample_entry.text())) parameters.params['advanced_ffc_downsample'] = str(self.downsample_entry.text()) def set_method(self): if self.eigen_rButton.isChecked(): logging.debug("Method: Eigen") parameters.params['advanced_ffc_method'] = "eigen" elif self.average_rButton.isChecked(): logging.debug("Method: Average") parameters.params['advanced_ffc_method'] = "average" elif self.ssim_rButton.isChecked(): logging.debug("Method: SSIM") parameters.params['advanced_ffc_method'] = "ssim" def set_method_from_params(self): if parameters.params['advanced_ffc_method'] == 1: self.eigen_rButton.setChecked(True) self.average_rButton.setChecked(False) self.ssim_rButton.setChecked(False) elif parameters.params['advanced_ffc_method'] == 2: self.eigen_rButton.setChecked(False) self.average_rButton.setChecked(True) self.ssim_rButton.setChecked(False) elif parameters.params['advanced_ffc_method'] == 3: self.eigen_rButton.setChecked(False) self.average_rButton.setChecked(False) self.ssim_rButton.setChecked(True) tofu-0.12.0/tofu/ez/GUI/Advanced/nlmdn.py000066400000000000000000000403121423713721100177560ustar00rootroot00000000000000import logging import os from shutil import rmtree from PyQt5.QtWidgets import ( QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox, QPushButton, QFileDialog, QMessageBox, ) from PyQt5.QtCore import Qt import tofu.ez.params as parameters from tofu.ez.main_nlm import main_tk LOG = logging.getLogger(__name__) class NLMDNGroup(QGroupBox): """ Non-local means de-noising settings """ def __init__(self): super().__init__() self.setTitle("Non-local-means Denoising") self.setStyleSheet("QGroupBox {color: royalblue;}") self.apply_to_reco_checkbox = QCheckBox("Automatically apply NLMDN to reconstructed slices") self.apply_to_reco_checkbox.stateChanged.connect(self.set_apply_to_reco) self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.clicked.connect(self.set_indir_button) self.select_img_button = QPushButton("Select one image") self.select_img_button.clicked.connect(self.select_image) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_indir_entry) self.output_dir_button = QPushButton("Select output directory or filename pattern") self.output_dir_button.clicked.connect(self.set_outdir_button) self.save_bigtif_checkbox = QCheckBox("Save in bigtif container") self.save_bigtif_checkbox.clicked.connect(self.set_save_bigtif) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_outdir_entry) self.similarity_radius_label = QLabel("Radius for similarity search") self.similarity_radius_entry = QLineEdit() self.similarity_radius_entry.editingFinished.connect(self.set_rad_sim_entry) self.patch_radius_label = QLabel("Radius of patches") self.patch_radius_entry = QLineEdit() self.patch_radius_entry.editingFinished.connect(self.set_rad_patch_entry) self.smoothing_label = QLabel("Smoothing control parameter") self.smoothing_entry = QLineEdit() self.smoothing_entry.editingFinished.connect(self.set_smoothing_entry) self.noise_std_label = QLabel("Noise standard deviation") self.noise_std_entry = QLineEdit() self.noise_std_entry.editingFinished.connect(self.set_noise_entry) self.window_label = QLabel("Window (optional)") self.window_entry = QLineEdit() self.window_entry.editingFinished.connect(self.set_window_entry) self.fast_checkbox = QCheckBox("Fast") self.fast_checkbox.clicked.connect(self.set_fast_checkbox) self.sigma_checkbox = QCheckBox("Estimate sigma") self.sigma_checkbox.clicked.connect(self.set_sigma_checkbox) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.delete_button = QPushButton("Delete reco dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.dry_button = QPushButton("Dry run") self.dry_button.clicked.connect(self.dry_button_pressed) self.apply_button = QPushButton("Apply filter") self.apply_button.clicked.connect(self.apply_button_pressed) # self.apply_button.setStyleSheet("color:royalblue; font-weight: bold;") self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.apply_to_reco_checkbox, 0, 0, 1, 1) layout.addWidget(self.input_dir_button, 1, 0, 1, 2) layout.addWidget(self.select_img_button, 1, 2, 1, 2) layout.addWidget(self.input_dir_entry, 2, 0, 1, 4) layout.addWidget(self.output_dir_button, 3, 0, 1, 2) layout.addWidget(self.save_bigtif_checkbox, 3, 2, 1, 2, Qt.AlignCenter) layout.addWidget(self.output_dir_entry, 4, 0, 1, 4) layout.addWidget(self.similarity_radius_label, 5, 0, 1, 2) layout.addWidget(self.similarity_radius_entry, 5, 2, 1, 2) layout.addWidget(self.patch_radius_label, 6, 0, 1, 2) layout.addWidget(self.patch_radius_entry, 6, 2, 1, 2) layout.addWidget(self.smoothing_label, 7, 0, 1, 2) layout.addWidget(self.smoothing_entry, 7, 2, 1, 2) layout.addWidget(self.noise_std_label, 8, 0, 1, 2) layout.addWidget(self.noise_std_entry, 8, 2, 1, 2) layout.addWidget(self.window_label, 9, 0, 1, 2) layout.addWidget(self.window_entry, 9, 2, 1, 2) layout.addWidget(self.fast_checkbox, 10, 0, 1, 2, Qt.AlignCenter) layout.addWidget(self.sigma_checkbox, 10, 2, 1, 2, Qt.AlignCenter) layout.addWidget(self.help_button, 11, 0, 1, 1) layout.addWidget(self.delete_button, 11, 1) layout.addWidget(self.dry_button, 11, 2) layout.addWidget(self.apply_button, 11, 3) self.setLayout(layout) def init_values(self): self.apply_to_reco_checkbox.setChecked(False) parameters.params['advanced_nlmdn_apply_after_reco'] = False self.input_dir_entry.setText(os.getcwd()) parameters.params['advanced_nlmdn_input_dir'] = os.getcwd() self.output_dir_entry.setText(os.getcwd() + '-nlmfilt') parameters.params['advanced_nlmdn_output_dir'] = os.getcwd() + '-nlmfilt' self.e_bigtif = False parameters.params['advanced_nlmdn_save_bigtiff'] = False self.similarity_radius_entry.setText("10") self.patch_radius_entry.setText("3") self.smoothing_entry.setText("0.0") self.noise_std_entry.setText("0.0") self.window_entry.setText("0.0") self.fast_checkbox.setChecked(True) self.e_fast = True self.sigma_checkbox.setChecked(False) self.e_sig = False self.e_dryrun = False def set_values_from_params(self): self.apply_to_reco_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_apply_after_reco'])) self.input_dir_entry.setText(str(parameters.params['advanced_nlmdn_input_dir'])) self.output_dir_entry.setText(str(parameters.params['advanced_nlmdn_output_dir'])) self.save_bigtif_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_save_bigtiff'])) self.similarity_radius_entry.setText(str(parameters.params['advanced_nlmdn_sim_search_radius'])) self.patch_radius_entry.setText(str(parameters.params['advanced_nlmdn_patch_radius'])) self.smoothing_entry.setText(str(parameters.params['advanced_nlmdn_smoothing_control'])) self.noise_std_entry.setText(str(parameters.params['advanced_nlmdn_noise_std'])) self.window_entry.setText(str(parameters.params['advanced_nlmdn_window'])) self.fast_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_fast'])) self.sigma_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_estimate_sigma'])) def set_apply_to_reco(self): LOG.debug( "Apply NLMDN to reconstructed slices checkbox: " + str(self.apply_to_reco_checkbox.isChecked()) ) parameters.params['advanced_nlmdn_apply_after_reco'] = bool( self.apply_to_reco_checkbox.isChecked() ) if self.apply_to_reco_checkbox.isChecked(): self.input_dir_button.setDisabled(True) self.select_img_button.setDisabled(True) self.input_dir_entry.setDisabled(True) self.dry_button.setDisabled(True) self.apply_button.setDisabled(True) self.output_dir_button.setDisabled(True) self.output_dir_entry.setDisabled(True) elif not self.apply_to_reco_checkbox.isChecked(): self.input_dir_button.setDisabled(False) self.select_img_button.setDisabled(False) self.input_dir_entry.setDisabled(False) self.dry_button.setDisabled(False) self.apply_button.setDisabled(False) self.output_dir_button.setDisabled(False) self.output_dir_entry.setDisabled(False) def set_indir_button(self): """ Saves directory specified by user in file-dialog for input tomographic data """ LOG.debug("Select input directory pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() self.input_dir_entry.setText(directory) parameters.params['advanced_nlmdn_input_dir'] = directory self.output_dir_entry.setText(directory + "-nlmfilt") parameters.params['advanced_nlmdn_output_dir'] = directory + "-nlmfilt" parameters.params['advanced_nlmdn_input_is_file'] = False def set_indir_entry(self): LOG.debug("Indir entry: " + str(self.input_dir_entry.text())) parameters.params['advanced_nlmdn_input_dir'] = str(self.input_dir_entry.text()) def select_image(self): LOG.debug("Select one image button pressed") options = QFileDialog.Options() file_path, _ = QFileDialog.getOpenFileName( self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options ) if file_path: img_name, img_ext = os.path.splitext(file_path) tmp = img_name + "-nlmfilt-%05i" + img_ext self.input_dir_entry.setText(file_path) self.output_dir_entry.setText(tmp) parameters.params['advanced_nlmdn_input_dir'] = file_path parameters.params['advanced_nlmdn_output_dir'] = tmp parameters.params['advanced_nlmdn_input_is_file'] = True def set_outdir_button(self): LOG.debug("Select output directory pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() self.output_dir_entry.setText(directory) parameters.params['advanced_nlmdn_output_dir'] = directory def set_save_bigtif(self): LOG.debug("Save bigtif checkbox: " + str(self.save_bigtif_checkbox.isChecked())) parameters.params['advanced_nlmdn_save_bigtiff'] = bool(self.save_bigtif_checkbox.isChecked()) def set_outdir_entry(self): LOG.debug("Outdir entry: " + str(self.output_dir_entry.text())) parameters.params['advanced_nlmdn_output_dir'] = str(self.output_dir_entry.text()) def set_rad_sim_entry(self): LOG.debug("Radius for similarity: " + str(self.similarity_radius_entry.text())) parameters.params['advanced_nlmdn_sim_search_radius'] = str(self.similarity_radius_entry.text()) def set_rad_patch_entry(self): LOG.debug("Radius of patches: " + str(self.patch_radius_entry.text())) parameters.params['advanced_nlmdn_patch_radius'] = str(self.patch_radius_entry.text()) def set_smoothing_entry(self): LOG.debug("Smoothing control: " + str(self.smoothing_entry.text())) parameters.params['advanced_nlmdn_smoothing_control'] = str(self.smoothing_entry.text()) def set_noise_entry(self): LOG.debug("Noise std: " + str(self.noise_std_entry.text())) parameters.params['advanced_nlmdn_noise_std'] = str(self.noise_std_entry.text()) def set_window_entry(self): LOG.debug("Window: " + str(self.window_entry.text())) parameters.params['advanced_nlmdn_window'] = str(self.window_entry.text()) def set_fast_checkbox(self): LOG.debug("Fast: " + str(self.fast_checkbox.isChecked())) parameters.params['advanced_nlmdn_fast'] = bool(self.fast_checkbox.isChecked()) def set_sigma_checkbox(self): LOG.debug("Estimate sigma: " + str(self.sigma_checkbox.isChecked())) parameters.params['advanced_nlmdn_estimate_sigma'] = bool(self.sigma_checkbox.isChecked()) def help_button_pressed(self): LOG.debug("Help Button Pressed") h = "" h += 'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n' h += "SerG, BMIT CLS, Dec. 2020." QMessageBox.information(self, "Help", h) def delete_button_pressed(self): LOG.debug("Delete Reco Button Pressed") """ Deletes the directory that contains reconstructed data """ LOG.debug("DELETE") msg = "Delete directory with reconstructed data?" dialog = QMessageBox.warning(self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No) if dialog == QMessageBox.Yes: if os.path.exists(str(parameters.params['advanced_nlmdn_output_dir'])): LOG.debug("YES") if parameters.params['advanced_nlmdn_output_dir'] == parameters.params['advanced_nlmdn_input_dir']: LOG.debug("Cannot delete: output directory is the same as input") else: rmtree(parameters.params['advanced_nlmdn_output_dir']) LOG.debug("Directory with reconstructed data was removed") else: LOG.debug("Directory does not exist") else: LOG.debug("NO") def dry_button_pressed(self): LOG.debug("Dry Run Button Pressed") parameters.params['advanced_nlmdn_dry_run'] = True self.apply_button_pressed() parameters.params['advanced_nlmdn_dry_run'] = False def apply_button_pressed(self): LOG.debug("Apply Filter Button Pressed") args = tk_args(parameters.params['advanced_nlmdn_apply_after_reco'], parameters.params['advanced_nlmdn_input_dir'], parameters.params['advanced_nlmdn_input_is_file'], parameters.params['advanced_nlmdn_output_dir'], parameters.params['advanced_nlmdn_save_bigtiff'], parameters.params['advanced_nlmdn_sim_search_radius'], parameters.params['advanced_nlmdn_patch_radius'], parameters.params['advanced_nlmdn_smoothing_control'], parameters.params['advanced_nlmdn_noise_std'], parameters.params['advanced_nlmdn_window'], parameters.params['advanced_nlmdn_fast'], parameters.params['advanced_nlmdn_estimate_sigma'], parameters.params['advanced_nlmdn_dry_run']) #LOG.debug(args.args) if os.path.exists(args.outdir) and not args.dryrun: title_text = "Warning: files can be overwritten" text1 = "Output directory exists. Files can be overwritten. Proceed?" dialog = QMessageBox.warning(self, title_text, text1, QMessageBox.Yes | QMessageBox.No) if dialog == QMessageBox.Yes: main_tk(args) QMessageBox.information(self, "Finished", "Finished") else: main_tk(args) QMessageBox.information(self, "Finished", "Finished") class tk_args: def __init__( self, e_apply_after_reco, e_indir, e_input_is_file, e_outdir, e_bigtif, e_r, e_dx, e_h, e_sig, e_w, e_fast, e_autosig, e_dryrun, ): self.args = {} # PATHS self.args["apply_after_reco"] = str(e_apply_after_reco) setattr(self, "apply_after_reco", self.args["apply_after_reco"]) self.args["indir"] = str(e_indir) setattr(self, "indir", self.args["indir"]) self.args["input_is_file"] = e_input_is_file setattr(self, "input_is_file", self.args["input_is_file"]) self.args["outdir"] = str(e_outdir) setattr(self, "outdir", self.args["outdir"]) # ALG PARAMS - MAIN self.args["search_r"] = int(e_r) setattr(self, "search_r", self.args["search_r"]) self.args["patch_r"] = int(e_dx) setattr(self, "patch_r", self.args["patch_r"]) self.args["h"] = float(e_h) setattr(self, "h", self.args["h"]) self.args["sig"] = float(e_sig) setattr(self, "sig", self.args["sig"]) # ALG PARAMS - optional self.args["w"] = float(e_w) setattr(self, "w", self.args["w"]) self.args["fast"] = bool(e_fast) setattr(self, "fast", self.args["fast"]) self.args["autosig"] = bool(e_autosig) setattr(self, "autosig", self.args["autosig"]) # Misc # self.args['inplace'] = bool(e_inplace.get()) # setattr(self, 'inplace', self.args['inplace']) self.args["bigtif"] = bool(e_bigtif) setattr(self, "bigtif", self.args["bigtif"]) self.args["dryrun"] = bool(e_dryrun) setattr(self, "dryrun", self.args["dryrun"]) tofu-0.12.0/tofu/ez/GUI/Advanced/optimization.py000066400000000000000000000075561423713721100214110ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class OptimizationGroup(QGroupBox): """ Optimization settings """ def __init__(self): super().__init__() self.setTitle("Optimization Settings") self.setStyleSheet("QGroupBox {color: orange;}") self.verbose_switch = QCheckBox("Enable verbose console output") self.verbose_switch.stateChanged.connect(self.set_verbose_switch) self.slice_memory_label = QLabel("Slice memory coefficient") self.slice_memory_entry = QLineEdit() tmpstr="Fraction of VRAM which will be used to store images \n" \ "Reserve ~2 GB of VRAM for computation \n" \ "Decrease the coefficient if you have very large data and start getting errors" self.slice_memory_entry.setToolTip(tmpstr) self.slice_memory_label.setToolTip(tmpstr) self.slice_memory_entry.editingFinished.connect(self.set_slice) self.num_GPU_label = QLabel("Number of GPUs") self.num_GPU_entry = QLineEdit() self.num_GPU_entry.editingFinished.connect(self.set_num_gpu) self.slices_per_device_label = QLabel("Slices per device") self.slices_per_device_entry = QLineEdit() self.slices_per_device_entry.editingFinished.connect(self.set_slices_per_device) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.verbose_switch, 0, 0) gpu_group = QGroupBox("GPU optimization") gpu_group.setCheckable(True) gpu_group.setChecked(False) gpu_layout = QGridLayout() gpu_layout.addWidget(self.slice_memory_label, 0, 0) gpu_layout.addWidget(self.slice_memory_entry, 0, 1) gpu_layout.addWidget(self.num_GPU_label, 1, 0) gpu_layout.addWidget(self.num_GPU_entry, 1, 1) gpu_layout.addWidget(self.slices_per_device_label, 2, 0) gpu_layout.addWidget(self.slices_per_device_entry, 2, 1) gpu_group.setLayout(gpu_layout) layout.addWidget(gpu_group, 1, 0) self.setLayout(layout) def init_values(self): self.verbose_switch.setChecked(False) parameters.params['advanced_optimize_verbose_console'] = False parameters.params['advanced_optimize_slice_mem_coeff'] = 0.7 self.slice_memory_entry.setText( str(parameters.params['advanced_optimize_slice_mem_coeff'])) self.num_GPU_entry.setText("") parameters.params['advanced_optimize_num_gpus'] = "" self.slices_per_device_entry.setText("") parameters.params['advanced_optimize_slices_per_device'] = "" def set_values_from_params(self): self.verbose_switch.setChecked(bool(parameters.params['advanced_optimize_verbose_console'])) self.slice_memory_entry.setText(str(parameters.params['advanced_optimize_slice_mem_coeff'])) self.num_GPU_entry.setText(str(parameters.params['advanced_optimize_num_gpus'])) self.slices_per_device_entry.setText(str(parameters.params['advanced_optimize_slices_per_device'])) def set_verbose_switch(self): LOG.debug("Verbose: " + str(self.verbose_switch.isChecked())) parameters.params['advanced_optimize_verbose_console'] = bool(self.verbose_switch.isChecked()) def set_slice(self): LOG.debug(self.slice_memory_entry.text()) parameters.params['advanced_optimize_slice_mem_coeff'] = str(self.slice_memory_entry.text()) def set_num_gpu(self): LOG.debug(self.num_GPU_entry.text()) parameters.params['advanced_optimize_num_gpus'] = str(self.num_GPU_entry.text()) def set_slices_per_device(self): LOG.debug(self.slices_per_device_entry.text()) parameters.params['advanced_optimize_slices_per_device'] = str(self.slices_per_device_entry.text())tofu-0.12.0/tofu/ez/GUI/Main/000077500000000000000000000000001423713721100154535ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Main/__init__.py000066400000000000000000000000001423713721100175520ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Main/centre_of_rotation.py000066400000000000000000000176041423713721100217200ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QRadioButton, QGroupBox, QLineEdit import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class CentreOfRotationGroup(QGroupBox): """ Centre of Rotation settings """ def __init__(self): super().__init__() self.setTitle("Centre of Rotation") self.setStyleSheet("QGroupBox {color: green;}") self.auto_correlate_rButton = QRadioButton() self.auto_correlate_rButton.setText("Auto: Correlate first/last projections") self.auto_correlate_rButton.clicked.connect(self.set_rButton) self.auto_minimize_rButton = QRadioButton() self.auto_minimize_rButton.setText("Auto: Minimize STD of a slice") self.auto_minimize_rButton.setToolTip( "Reconstructed patches are saved \nin your-temporary-data-folder\\axis-search" ) self.auto_minimize_rButton.clicked.connect(self.set_rButton) self.define_axis_rButton = QRadioButton() self.define_axis_rButton.setText("Define rotation axis manually") self.define_axis_rButton.clicked.connect(self.set_rButton) self.search_rotation_label = QLabel() self.search_rotation_label.setText("Search rotation axis in [start, stop, step] interval") self.search_rotation_entry = QLineEdit() self.search_rotation_entry.editingFinished.connect(self.set_search_rotation) self.search_rotation_entry.setStyleSheet("background-color:white") self.search_in_slice_label = QLabel() self.search_in_slice_label.setText("Search in slice from row number") self.search_in_slice_entry = QLineEdit() self.search_in_slice_entry.editingFinished.connect(self.set_search_slice) self.search_in_slice_entry.setStyleSheet("background-color:white") self.size_of_recon_label = QLabel() self.size_of_recon_label.setText("Size of reconstructed patch [pixel]") self.size_of_recon_entry = QLineEdit() self.size_of_recon_entry.editingFinished.connect(self.set_size_of_reco) self.size_of_recon_entry.setStyleSheet("background-color:white") self.axis_col_label = QLabel() self.axis_col_label.setText("Axis is in column No [pixel]") self.axis_col_entry = QLineEdit() self.axis_col_entry.editingFinished.connect(self.set_axis_col) self.axis_col_entry.setStyleSheet("background-color:white") self.inc_axis_label = QLabel() self.inc_axis_label.setText("Increment axis every reconstruction") self.inc_axis_entry = QLineEdit() self.inc_axis_entry.editingFinished.connect(self.set_axis_inc) self.inc_axis_entry.setStyleSheet("background-color:white") self.image_midpoint_rButton = QRadioButton() self.image_midpoint_rButton.setText("Use image midpoint (for half-acquisition)") self.image_midpoint_rButton.clicked.connect(self.set_rButton) # TODO Used for proper spacing - should be a better way self.blank_label = QLabel(" ") self.blank_label2 = QLabel(" ") self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.auto_correlate_rButton, 0, 0) layout.addWidget(self.blank_label, 0, 1) layout.addWidget(self.blank_label2, 0, 2) layout.addWidget(self.auto_minimize_rButton, 1, 0) layout.addWidget(self.search_rotation_label, 2, 0) layout.addWidget(self.search_rotation_entry, 2, 1, 1, 2) layout.addWidget(self.search_in_slice_label, 3, 0) layout.addWidget(self.search_in_slice_entry, 3, 1, 1, 2) layout.addWidget(self.size_of_recon_label, 4, 0) layout.addWidget(self.size_of_recon_entry, 4, 1, 1, 2) layout.addWidget(self.define_axis_rButton, 5, 0) layout.addWidget(self.axis_col_label, 6, 0) layout.addWidget(self.axis_col_entry, 6, 1, 1, 2) layout.addWidget(self.inc_axis_label, 7, 0) layout.addWidget(self.inc_axis_entry, 7, 1, 1, 2) layout.addWidget(self.image_midpoint_rButton, 8, 0) self.setLayout(layout) def init_values(self): self.auto_correlate_rButton.setChecked(True) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(False) self.set_rButton() self.search_rotation_entry.setText("1010,1030,0.5") self.search_in_slice_entry.setText("100") self.size_of_recon_entry.setText("256") self.axis_col_entry.setText("0.0") self.inc_axis_entry.setText("0.0") # self.bypass_checkbox.setChecked(False) def set_values_from_params(self): self.set_rButton_from_params() self.search_rotation_entry.setText(str(parameters.params['main_cor_axis_search_interval'])) self.search_in_slice_entry.setText(str(parameters.params['main_cor_search_row_start'])) self.size_of_recon_entry.setText(str(parameters.params['main_cor_recon_patch_size'])) self.axis_col_entry.setText(str(parameters.params['main_cor_axis_column'])) self.inc_axis_entry.setText(str(parameters.params['main_cor_axis_increment_step'])) def set_rButton(self): if self.auto_correlate_rButton.isChecked(): LOG.debug("Auto Correlate") parameters.params['main_cor_axis_search_method'] = 1 elif self.auto_minimize_rButton.isChecked(): LOG.debug("Auto Minimize") parameters.params['main_cor_axis_search_method'] = 2 elif self.define_axis_rButton.isChecked(): LOG.debug("Define axis") parameters.params['main_cor_axis_search_method'] = 3 elif self.image_midpoint_rButton.isChecked(): LOG.debug("Use image midpoint") parameters.params['main_cor_axis_search_method'] = 4 def set_rButton_from_params(self): if parameters.params['main_cor_axis_search_method'] == 1: self.auto_correlate_rButton.setChecked(True) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(False) elif parameters.params['main_cor_axis_search_method'] == 2: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(True) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(False) elif parameters.params['main_cor_axis_search_method'] == 3: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(True) self.image_midpoint_rButton.setChecked(False) elif parameters.params['main_cor_axis_search_method'] == 4: self.auto_correlate_rButton.setChecked(False) self.auto_minimize_rButton.setChecked(False) self.define_axis_rButton.setChecked(False) self.image_midpoint_rButton.setChecked(True) def set_search_rotation(self): LOG.debug(self.search_rotation_entry.text()) parameters.params['main_cor_axis_search_interval'] = str(self.search_rotation_entry.text()) def set_search_slice(self): LOG.debug(self.search_in_slice_entry.text()) parameters.params['main_cor_search_row_start'] = str(self.search_in_slice_entry.text()) def set_size_of_reco(self): LOG.debug(self.size_of_recon_entry.text()) parameters.params['main_cor_recon_patch_size'] = str(self.size_of_recon_entry.text()) def set_axis_col(self): LOG.debug(self.axis_col_entry.text()) parameters.params['main_cor_axis_column'] = str(self.axis_col_entry.text()) def set_axis_inc(self): LOG.debug(self.inc_axis_entry.text()) parameters.params['main_cor_axis_increment_step'] = str(self.inc_axis_entry.text())tofu-0.12.0/tofu/ez/GUI/Main/config.py000066400000000000000000001616231423713721100173030ustar00rootroot00000000000000import os import logging import numpy as np from shutil import rmtree from PyQt5.QtWidgets import ( QMessageBox, QFileDialog, QCheckBox, QPushButton, QGridLayout, QLabel, QGroupBox, QLineEdit, ) from PyQt5.QtCore import QCoreApplication from PyQt5.QtCore import pyqtSignal from PyQt5.QtCore import Qt from tofu.ez.main import execute_reconstruction, clean_tmp_dirs from tofu.ez.yaml_in_out import Yaml_IO from tofu.ez.GUI.message_dialog import warning_message import tofu.ez.params as parameters #TODO Get rid of the old args structure and store all parameters # like tofu does LOG = logging.getLogger(__name__) class ConfigGroup(QGroupBox): """ Setup and configuration settings """ # Used to send signal to ezufo_launcher when settings are imported https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect signal_update_vals_from_params = pyqtSignal(dict) # Used to send signal when reconstruction is done signal_reco_done = pyqtSignal(dict) def __init__(self): super().__init__() self.setTitle("Configuration") self.setStyleSheet("QGroupBox {color: purple;}") self.yaml_io = Yaml_IO() # Select input directory self.input_dir_select = QPushButton("Select input directory (or paste abs. path)") self.input_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_dir) self.input_dir_select.pressed.connect(self.select_input_dir) # Save .params checkbox self.save_params_checkbox = QCheckBox("Save args in .params file") self.save_params_checkbox.stateChanged.connect(self.set_save_args) # Select output directory self.output_dir_select = QPushButton() self.output_dir_select.setText("Select output directory (or paste abs. path)") self.output_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_dir) self.output_dir_select.pressed.connect(self.select_output_dir) # Save in separate files or in one huge tiff file self.bigtiff_checkbox = QCheckBox() self.bigtiff_checkbox.setText("Save slices in multipage tiffs") self.bigtiff_checkbox.setToolTip( "Will save images in bigtiff containers. \n" "Note that some temporary data is always saved in bigtiffs.\n" "Use bio-formats importer plugin for imagej or fiji to open the bigtiffs." ) self.bigtiff_checkbox.stateChanged.connect(self.set_big_tiff) # Crop in the reconstruction plane self.preproc_checkbox = QCheckBox() self.preproc_checkbox.setText("Preprocess with a generic ufo-launch pipeline, f.i.") self.preproc_checkbox.setToolTip( "Selected ufo filters will be applied to each " "image before reconstruction begins. \n" 'To print the list of filters use "ufo-query -l" command. \n' 'Parameters of each filter can be seen with "ufo-query -p filtername".' ) self.preproc_checkbox.stateChanged.connect(self.set_preproc) self.preproc_entry = QLineEdit() self.preproc_entry.editingFinished.connect(self.set_preproc_entry) # Names of directories with flats/darks/projections frames self.e_DIRTYP = ["darks", "flats", "tomo", "flats2"] self.dir_name_label = QLabel() self.dir_name_label.setText("Name of flats/darks/tomo subdirectories in each CT data set") self.darks_entry = QLineEdit() self.darks_entry.editingFinished.connect(self.set_darks) self.flats_entry = QLineEdit() self.flats_entry.editingFinished.connect(self.set_flats) self.tomo_entry = QLineEdit() self.tomo_entry.editingFinished.connect(self.set_tomo) self.flats2_entry = QLineEdit() self.flats2_entry.editingFinished.connect(self.set_flats2) # Select flats/darks/flats2 for use in multiple reconstructions self.use_common_flats_darks_checkbox = QCheckBox() self.use_common_flats_darks_checkbox.setText( "Use common flats/darks across multiple experiments" ) self.use_common_flats_darks_checkbox.stateChanged.connect(self.set_flats_darks_checkbox) self.select_darks_button = QPushButton("Select path to darks (or paste abs. path)") self.select_darks_button.setToolTip("Background detector noise") self.select_darks_button.clicked.connect(self.select_darks_button_pressed) self.select_flats_button = QPushButton("Select path to flats (or paste abs. path)") self.select_flats_button.setToolTip("Images without sample in the beam") self.select_flats_button.clicked.connect(self.select_flats_button_pressed) self.select_flats2_button = QPushButton("Select path to flats2 (or paste abs. path)") self.select_flats2_button.setToolTip( "If selected, it will be assumed that flats were \n" "acquired before projections while flats2 after \n" "and interpolation will be used to compute intensity of flat image \n" "for each projection between flats and flats2" ) self.select_flats2_button.clicked.connect(self.select_flats2_button_pressed) self.darks_absolute_entry = QLineEdit() self.darks_absolute_entry.setText("Absolute path to darks") self.darks_absolute_entry.editingFinished.connect(self.set_common_darks) self.flats_absolute_entry = QLineEdit() self.flats_absolute_entry.setText("Absolute path to flats") self.flats_absolute_entry.editingFinished.connect(self.set_common_flats) self.use_flats2_checkbox = QCheckBox("Use common flats2") self.use_flats2_checkbox.clicked.connect(self.set_use_flats2) self.flats2_absolute_entry = QLineEdit() self.flats2_absolute_entry.editingFinished.connect(self.set_common_flats2) self.flats2_absolute_entry.setText("Absolute path to flats2") # Select temporary directory self.temp_dir_select = QPushButton() self.temp_dir_select.setText("Select temporary directory (or paste abs. path)") self.temp_dir_select.setToolTip( "Temporary data will be saved there.\n" "note that the size of temporary data can exceed 300 GB in some cases." ) self.temp_dir_select.pressed.connect(self.select_temp_dir) self.temp_dir_select.setStyleSheet("background-color:lightgrey; font: 12pt;") self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_dir) # Keep temp data selection self.keep_tmp_data_checkbox = QCheckBox() self.keep_tmp_data_checkbox.setText("Keep all temp data till the end of reconstruction") self.keep_tmp_data_checkbox.setToolTip( "Useful option to inspect how images change at each step" ) self.keep_tmp_data_checkbox.stateChanged.connect(self.set_keep_tmp_data) # IMPORT SETTINGS FROM FILE self.open_settings_file = QPushButton() self.open_settings_file.setText("Import parameters from file") self.open_settings_file.setStyleSheet("background-color:lightgrey; font: 12pt;") self.open_settings_file.pressed.connect(self.import_settings_button_pressed) # EXPORT SETTINGS TO FILE self.save_settings_file = QPushButton() self.save_settings_file.setText("Export parameters to file") self.save_settings_file.setStyleSheet("background-color:lightgrey; font: 12pt;") self.save_settings_file.pressed.connect(self.export_settings_button_pressed) # QUIT self.quit_button = QPushButton() self.quit_button.setText("Quit") self.quit_button.setStyleSheet("background-color:lightgrey; font: 13pt; font-weight: bold;") self.quit_button.clicked.connect(self.quit_button_pressed) # HELP self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.setStyleSheet("background-color:lightgrey; font: 13pt; font-weight: bold") self.help_button.clicked.connect(self.help_button_pressed) # DELETE self.delete_reco_dir_button = QPushButton() self.delete_reco_dir_button.setText("Delete reco dir") self.delete_reco_dir_button.setStyleSheet( "background-color:lightgrey; font: 13pt; font-weight: bold" ) self.delete_reco_dir_button.clicked.connect(self.delete_button_pressed) # DRY RUN self.dry_run_button = QPushButton() self.dry_run_button.setText("Dry run") self.dry_run_button.setStyleSheet( "background-color:lightgrey; font: 13pt; font-weight: bold" ) self.dry_run_button.clicked.connect(self.dryrun_button_pressed) # RECONSTRUCT self.reco_button = QPushButton() self.reco_button.setText("Reconstruct") self.reco_button.setStyleSheet( "background-color:lightgrey;color:royalblue; font: 14pt; font-weight: bold;" ) self.reco_button.clicked.connect(self.reco_button_pressed) # OPEN IMAGE AFTER RECONSTRUCT self.open_image_after_reco_checkbox = QCheckBox() self.open_image_after_reco_checkbox.setText( "Load images and open viewer after reconstruction" ) self.open_image_after_reco_checkbox.clicked.connect(self.set_open_image_after_reco) self.set_layout() def set_layout(self): """ Sets the layout of buttons, labels, etc. for config group """ layout = QGridLayout() checkbox_groupbox = QGroupBox() checkbox_layout = QGridLayout() checkbox_layout.addWidget(self.save_params_checkbox, 0, 0) checkbox_layout.addWidget(self.bigtiff_checkbox, 1, 0) checkbox_layout.addWidget(self.open_image_after_reco_checkbox, 2, 0) checkbox_layout.addWidget(self.keep_tmp_data_checkbox, 3, 0) checkbox_groupbox.setLayout(checkbox_layout) layout.addWidget(checkbox_groupbox, 0, 4, 4, 1) layout.addWidget(self.input_dir_select, 0, 0) layout.addWidget(self.input_dir_entry, 0, 1, 1, 3) layout.addWidget(self.output_dir_select, 1, 0) layout.addWidget(self.output_dir_entry, 1, 1, 1, 3) layout.addWidget(self.temp_dir_select, 2, 0) layout.addWidget(self.temp_dir_entry, 2, 1, 1, 3) layout.addWidget(self.preproc_checkbox, 3, 0) layout.addWidget(self.preproc_entry, 3, 1, 1, 3) fdt_groupbox = QGroupBox() fdt_layout = QGridLayout() fdt_layout.addWidget(self.dir_name_label, 0, 0) fdt_layout.addWidget(self.darks_entry, 0, 1) fdt_layout.addWidget(self.flats_entry, 0, 2) fdt_layout.addWidget(self.tomo_entry, 0, 3) fdt_layout.addWidget(self.flats2_entry, 0, 4) fdt_layout.addWidget(self.use_common_flats_darks_checkbox, 1, 0) fdt_layout.addWidget(self.select_darks_button, 1, 1) fdt_layout.addWidget(self.select_flats_button, 1, 2) fdt_layout.addWidget(self.select_flats2_button, 1, 4) fdt_layout.addWidget(self.darks_absolute_entry, 2, 1) fdt_layout.addWidget(self.flats_absolute_entry, 2, 2) fdt_layout.addWidget(self.use_flats2_checkbox, 2, 3, Qt.AlignRight) fdt_layout.addWidget(self.flats2_absolute_entry, 2, 4) fdt_groupbox.setLayout(fdt_layout) layout.addWidget(fdt_groupbox, 4, 0, 1, 5) layout.addWidget(self.open_settings_file, 5, 0, 1, 3) layout.addWidget(self.save_settings_file, 5, 3, 1, 2) layout.addWidget(self.quit_button, 6, 0) layout.addWidget(self.help_button, 6, 1) layout.addWidget(self.delete_reco_dir_button, 6, 2) layout.addWidget(self.dry_run_button, 6, 3) layout.addWidget(self.reco_button, 6, 4) self.setLayout(layout) def init_values(self): """ Sets the initial default values of config group """ # If we're on a computer with access to network indir = os.path.expanduser('~')#"/beamlinedata/BMIT/projects/" if os.path.isdir(indir): self.input_dir_entry.setText(indir) outdir = os.path.abspath(indir + "/rec") self.output_dir_entry.setText(outdir) # Otherwise use this as default self.save_params_checkbox.setChecked(True) parameters.params['main_config_save_params'] = True parameters.params['main_config_save_multipage_tiff'] = False self.preproc_checkbox.setChecked(False) self.set_preproc() parameters.params['main_config_preprocess'] = False self.preproc_entry.setText("remove-outliers size=3 threshold=500 sign=1") self.darks_entry.setText("darks") self.flats_entry.setText("flats") self.tomo_entry.setText("tomo") self.flats2_entry.setText("flats2") self.use_common_flats_darks_checkbox.setChecked(False) self.darks_absolute_entry.setText("Absolute path to darks") self.flats_absolute_entry.setText("Absolute path to flats") self.use_common_flats_darks_checkbox.setChecked(False) self.flats2_absolute_entry.setText("Absolute path to flats2") self.temp_dir_entry.setText(os.path.join(os.path.expanduser('~'),"tmp-ezufo")) self.keep_tmp_data_checkbox.setChecked(False) parameters.params['main_config_keep_temp'] = False self.set_temp_dir() self.dry_run_button.setChecked(False) parameters.params['main_config_dry_run'] = False parameters.params['main_config_open_viewer'] = False self.open_image_after_reco_checkbox.setChecked(False) def set_values_from_params(self): """ Updates displayed values for config group Called when .yaml file of params is loaded """ self.input_dir_entry.setText(parameters.params['main_config_input_dir']) self.save_params_checkbox.setChecked(parameters.params['main_config_save_params']) self.output_dir_entry.setText(parameters.params['main_config_output_dir']) self.bigtiff_checkbox.setChecked(parameters.params['main_config_save_multipage_tiff']) self.preproc_checkbox.setChecked(parameters.params['main_config_preprocess']) self.preproc_entry.setText(parameters.params['main_config_preprocess_command']) self.darks_entry.setText(parameters.params['main_config_darks_dir_name']) self.flats_entry.setText(parameters.params['main_config_flats_dir_name']) self.tomo_entry.setText(parameters.params['main_config_tomo_dir_name']) self.flats2_entry.setText(parameters.params['main_config_flats2_dir_name']) self.temp_dir_entry.setText(parameters.params['main_config_temp_dir']) self.keep_tmp_data_checkbox.setChecked(parameters.params['main_config_keep_temp']) self.dry_run_button.setChecked(parameters.params['main_config_dry_run']) self.open_image_after_reco_checkbox.setChecked(parameters.params['main_config_open_viewer']) self.use_common_flats_darks_checkbox.setChecked(parameters.params['main_config_common_flats_darks']) self.darks_absolute_entry.setText(parameters.params['main_config_darks_path']) self.flats_absolute_entry.setText(parameters.params['main_config_flats_path']) self.use_flats2_checkbox.setChecked(parameters.params['main_config_flats2_checkbox']) self.flats2_absolute_entry.setText(parameters.params['main_config_flats2_path']) def select_input_dir(self): """ Saves directory specified by user in file-dialog for input tomographic data """ dir_explore = QFileDialog(self) dir = dir_explore.getExistingDirectory(directory=self.input_dir_entry.text()) self.input_dir_entry.setText(dir) parameters.params['main_config_input_dir'] = dir def set_input_dir(self): LOG.debug(str(self.input_dir_entry.text())) parameters.params['main_config_input_dir'] = str(self.input_dir_entry.text()) def select_output_dir(self): dir_explore = QFileDialog(self) dir = dir_explore.getExistingDirectory(directory=self.output_dir_entry.text()) self.output_dir_entry.setText(dir) parameters.params['main_config_output_dir'] = dir def set_output_dir(self): LOG.debug(str(self.output_dir_entry.text())) parameters.params['main_config_output_dir'] = str(self.output_dir_entry.text()) def set_big_tiff(self): LOG.debug("Bigtiff: " + str(self.bigtiff_checkbox.isChecked())) parameters.params['main_config_save_multipage_tiff'] = bool(self.bigtiff_checkbox.isChecked()) def set_preproc(self): LOG.debug("Preproc: " + str(self.preproc_checkbox.isChecked())) parameters.params['main_config_preprocess'] = bool(self.preproc_checkbox.isChecked()) def set_preproc_entry(self): LOG.debug(self.preproc_entry.text()) parameters.params['main_config_preprocess_command'] = str(self.preproc_entry.text()) def set_open_image_after_reco(self): LOG.debug( "Switch to Image Viewer After Reco: " + str(self.open_image_after_reco_checkbox.isChecked()) ) parameters.params['main_config_open_viewer'] = bool(self.open_image_after_reco_checkbox.isChecked()) def set_darks(self): LOG.debug(self.darks_entry.text()) self.e_DIRTYP[0] = str(self.darks_entry.text()) parameters.params['main_config_darks_dir_name'] = str(self.darks_entry.text()) def set_flats(self): LOG.debug(self.flats_entry.text()) self.e_DIRTYP[1] = str(self.flats_entry.text()) parameters.params['main_config_flats_dir_name'] = str(self.flats_entry.text()) def set_tomo(self): LOG.debug(self.tomo_entry.text()) self.e_DIRTYP[2] = str(self.tomo_entry.text()) parameters.params['main_config_tomo_dir_name'] = str(self.tomo_entry.text()) def set_flats2(self): LOG.debug(self.flats2_entry.text()) self.e_DIRTYP[3] = str(self.flats2_entry.text()) parameters.params['main_config_flats2_dir_name'] = str(self.flats2_entry.text()) def set_fdt_names(self): self.set_darks() self.set_flats() self.set_flats2() self.set_tomo() def set_flats_darks_checkbox(self): LOG.debug( "Use same flats/darks across multiple experiments: " + str(self.use_common_flats_darks_checkbox.isChecked()) ) parameters.params['main_config_common_flats_darks'] = bool( self.use_common_flats_darks_checkbox.isChecked() ) def select_darks_button_pressed(self): LOG.debug("Select path to darks pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=parameters.params['main_config_input_dir']) self.darks_absolute_entry.setText(directory) parameters.params['main_config_darks_path'] = directory def select_flats_button_pressed(self): LOG.debug("Select path to flats pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=parameters.params['main_config_input_dir']) self.flats_absolute_entry.setText(directory) parameters.params['main_config_flats_path'] = directory def select_flats2_button_pressed(self): LOG.debug("Select path to flats2 pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory(directory=parameters.params['main_config_input_dir']) self.flats2_absolute_entry.setText(directory) parameters.params['main_config_flats2_path'] = directory def set_common_darks(self): LOG.debug("Common darks path: " + str(self.darks_absolute_entry.text())) parameters.params['main_config_darks_path'] = str(self.darks_absolute_entry.text()) def set_common_flats(self): LOG.debug("Common flats path: " + str(self.flats_absolute_entry.text())) parameters.params['main_config_flats_path'] = str(self.flats_absolute_entry.text()) def set_use_flats2(self): LOG.debug("Use common flats2 checkbox: " + str(self.use_flats2_checkbox.isChecked())) parameters.params['main_config_flats2_checkbox'] = bool(self.use_flats2_checkbox.isChecked()) def set_common_flats2(self): LOG.debug("Common flats2 path: " + str(self.flats2_absolute_entry.text())) parameters.params['main_config_flats2_path'] = str(self.flats2_absolute_entry.text()) def select_temp_dir(self): dir_explore = QFileDialog(self) tmp_dir = dir_explore.getExistingDirectory(directory=self.temp_dir_entry.text()) self.temp_dir_entry.setText(tmp_dir) def set_temp_dir(self): LOG.debug(str(self.temp_dir_entry.text())) parameters.params['main_config_temp_dir'] = str(self.temp_dir_entry.text()) def set_keep_tmp_data(self): LOG.debug("Keep tmp: " + str(self.keep_tmp_data_checkbox.isChecked())) parameters.params['main_config_keep_temp'] = bool(self.keep_tmp_data_checkbox.isChecked()) def quit_button_pressed(self): """ Displays confirmation dialog and cleans temporary directories """ LOG.debug("QUIT") reply = QMessageBox.question( self, "Quit", "Are you sure you want to quit?", QMessageBox.Yes | QMessageBox.No, QMessageBox.No, ) if reply == QMessageBox.Yes: # remove all directories with projections clean_tmp_dirs(parameters.params['main_config_temp_dir'], self.get_fdt_names()) # remove axis-search dir too tmp = os.path.join(parameters.params['main_config_temp_dir'], 'axis-search') QCoreApplication.instance().quit() else: pass def help_button_pressed(self): """ Displays pop-up help information """ LOG.debug("HELP") h = "This utility provides an interface to the ufo-kit software package.\n" h += "Use it for batch processing and optimization of reconstruction parameters.\n" h += "It creates a list of paths to all CT directories in the _input_ directory.\n" h += "A CT directory is defined as directory with at least \n" h += "_flats_, _darks_, _tomo_, and, optionally, _flats2_ subdirectories, \n" h += "which are not empty and contain only *.tif files. Names of CT\n" h += "directories are compared with the directory tree in the _output_ directory.\n" h += ( "(Note: relative directory tree in _input_ is preserved when writing results to the" " _output_.)\n" ) h += ( "Those CT sets will be reconstructed, whose names are not yet in the _output_" " directory." ) h += "Program will create an array of ufo/tofu commands according to defined parameters \n" h += ( "and then execute them sequentially. These commands can be also printed on the" " screen.\n" ) h += "Note2: if you bin in preprocess the center of rotation will change a lot; \n" h += 'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n' h += ( "Created by Sergei Gasilov, BMIT CLS, Dec. 2018.\n Extended by Iain Emslie, Summer" " 2021." ) QMessageBox.information(self, "Help", h) def delete_button_pressed(self): """ Deletes the directory that contains reconstructed data """ LOG.debug("DELETE") msg = "Delete directory with reconstructed data?" dialog = QMessageBox.warning( self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No ) if dialog == QMessageBox.Yes: if os.path.exists(str(parameters.params['main_config_output_dir'])): LOG.debug("YES") if parameters.params['main_config_output_dir'] == parameters.params['main_config_input_dir']: LOG.debug("Cannot delete: output directory is the same as input") else: try: rmtree(parameters.params['main_config_output_dir']) except: warning_message('Error while deleting directory') LOG.debug("Directory with reconstructed data was removed") else: LOG.debug("Directory does not exist") else: LOG.debug("NO") def dryrun_button_pressed(self): """ Sets the dry-run parameter for Tofu to True and calls reconstruction """ LOG.debug("DRY") parameters.params['main_config_dry_run'] = str(True) self.reco_button_pressed() parameters.params['main_config_dry_run'] = bool(False) def set_save_args(self): LOG.debug("Save args: " + str(self.save_params_checkbox.isChecked())) parameters.params['main_config_save_params'] = bool(self.save_params_checkbox.isChecked()) def export_settings_button_pressed(self): """ Saves currently displayed GUI settings to an external .yaml file specified by user """ LOG.debug("Save settings pressed") options = QFileDialog.Options() fileName, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "YAML Files (*.yaml);; All Files (*)", options=options, ) if fileName: LOG.debug("Export YAML Path: " + fileName) file_extension = os.path.splitext(fileName) if file_extension[-1] == "": fileName = fileName + ".yaml" # Create and write to YAML file based on given fileName self.yaml_io.write_yaml(fileName, parameters.params) def import_settings_button_pressed(self): """ Loads external settings from .yaml file specified by user Signal is sent to enable updating of displayed GUI values """ LOG.debug("Import settings pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "QFileDialog.getOpenFileName()", "", "YAML Files (*.yaml);; All Files (*)", options=options, ) if filePath: LOG.debug("Import YAML Path: " + filePath) yaml_data = self.yaml_io.read_yaml(filePath) parameters.params = dict(yaml_data) self.signal_update_vals_from_params.emit(parameters.params) def reco_button_pressed(self): """ Gets the settings set by the user in the GUI These are then passed to execute_reconstruction """ LOG.debug("RECO") LOG.debug(parameters.params) self.run_reconstruction(parameters.params, batch_run=False) def run_reconstruction(self, params, batch_run): try: self.validate_input() args = tk_args(params['main_config_input_dir'], params['main_config_temp_dir'], params['main_config_output_dir'], params['main_config_save_multipage_tiff'], params['main_cor_axis_search_method'], params['main_cor_axis_search_interval'], params['main_cor_search_row_start'], params['main_cor_recon_patch_size'], params['main_cor_axis_column'], params['main_cor_axis_increment_step'], params['main_filters_remove_spots'], params['main_filters_remove_spots_threshold'], params['main_filters_remove_spots_blur_sigma'], params['main_filters_ring_removal'], params['main_filters_ring_removal_ufo_lpf'], params['main_filters_ring_removal_ufo_lpf_1d_or_2d'], params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'], params['main_filters_ring_removal_ufo_lpf_sigma_vertical'], params['main_filters_ring_removal_sarepy_window_size'], params['main_filters_ring_removal_sarepy_wide'], params['main_filters_ring_removal_sarepy_window'], params['main_filters_ring_removal_sarepy_SNR'], params['main_pr_phase_retrieval'], params['main_pr_photon_energy'], params['main_pr_pixel_size'], params['main_pr_detector_distance'], params['main_pr_delta_beta_ratio'], params['main_region_select_rows'], params['main_region_first_row'], params['main_region_number_rows'], params['main_region_nth_row'], params['main_region_clip_histogram'], params['main_region_bit_depth'], params['main_region_histogram_min'], params['main_region_histogram_max'], params['main_config_preprocess'], params['main_config_preprocess_command'], params['main_region_rotate_volume_clock'], params['main_region_crop_slices'], params['main_region_crop_x'], params['main_region_crop_width'], params['main_region_crop_y'], params['main_region_crop_height'], params['main_config_dry_run'], params['main_config_save_params'], params['main_config_keep_temp'], params['advanced_ffc_sinFFC'], params['advanced_ffc_method'], params['advanced_ffc_eigen_pco_reps'], params['advanced_ffc_eigen_pco_downsample'], params['advanced_ffc_downsample'], params['main_config_common_flats_darks'], params['main_config_darks_path'], params['main_config_flats_path'], params['main_config_flats2_checkbox'], params['main_config_flats2_path'], # NLMDN Parameters params['advanced_nlmdn_apply_after_reco'], params['advanced_nlmdn_input_dir'], params['advanced_nlmdn_input_is_file'], params['advanced_nlmdn_output_dir'], params['advanced_nlmdn_save_bigtiff'], params['advanced_nlmdn_sim_search_radius'], params['advanced_nlmdn_patch_radius'], params['advanced_nlmdn_smoothing_control'], params['advanced_nlmdn_noise_std'], params['advanced_nlmdn_window'], params['advanced_nlmdn_fast'], params['advanced_nlmdn_estimate_sigma'], params['advanced_nlmdn_dry_run'], # Advanced Parameters params['advanced_advtofu_extended_settings'], params['advanced_advtofu_lamino_angle'], params['advanced_adv_tofu_z_axis_rotation'], params['advanced_advtofu_center_position_z'], params['advanced_advtofu_y_axis_rotation'], params['advanced_advtofu_aux_ffc_dark_scale'], params['advanced_advtofu_aux_ffc_flat_scale'], params['advanced_optimize_verbose_console'], params['advanced_optimize_slice_mem_coeff'], params['advanced_optimize_num_gpus'], params['advanced_optimize_slices_per_device'] ) execute_reconstruction(args, self.get_fdt_names()) if batch_run is False: msg = "Done. See output in terminal for details." QMessageBox.information(self, "Finished", msg) if not params['main_config_dry_run']: self.signal_reco_done.emit(params) except InvalidInputError as err: msg = "" err_arg = err.args msg += err.args[0] QMessageBox.information(self, "Invalid Input Error", msg) # NEED TO DETERMINE VALID RANGES # ALSO CHECK TYPES SOMEHOW def validate_input(self): """ Determines whether user-input values are valid """ # Search rotation: main_cor_axis_search_interval # Search in slice: main_cor_search_row_start if int(parameters.params['main_cor_search_row_start']) < 0: raise InvalidInputError("Value out of range for: Search in slice from row number") # Size of reconstructed: main_cor_recon_patch_size if int(parameters.params['main_cor_recon_patch_size']) < 0: raise InvalidInputError("Value out of range for: Size of reconstructed patch [pixel]") # Axis is in column No: main_cor_axis_column if float(parameters.params['main_cor_axis_column']) < 0: raise InvalidInputError("Value out of range for: Axis is in column No [pixel]") # Increment axis: main_cor_axis_increment_step if float(parameters.params['main_cor_axis_increment_step']) < 0: raise InvalidInputError("Value out of range for: Increment axis every reconstruction") # Threshold: main_filters_remove_spots_threshold if int(parameters.params['main_filters_remove_spots_threshold']) < 0: raise InvalidInputError("Value out of range for: Threshold (prominence of the spot) [counts]") # Spot blur: main_filters_remove_spots_blur_sigma if int(parameters.params['main_filters_remove_spots_blur_sigma']) < 0: raise InvalidInputError("Value out of range for: Spot blur. sigma [pixels]") # Sigma: e_sig_hor if int(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal']) < 0: raise InvalidInputError("Value out of range for: ufo ring-removal sigma horizontal") # Sigma: e_sig_ver if int(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical']) < 0: raise InvalidInputError("Value out of range for: ufo ring-removal sigma vertical") # Window size: main_filters_ring_removal_sarepy_window_size if int(parameters.params['main_filters_ring_removal_sarepy_window_size']) < 0: raise InvalidInputError("Value out of range for: window size") # Wind: main_filters_ring_removal_sarepy_window if int(parameters.params['main_filters_ring_removal_sarepy_window']) < 0: raise InvalidInputError("Value out of range for: wind") # SNR: main_filters_ring_removal_sarepy_SNR if int(parameters.params['main_filters_ring_removal_sarepy_SNR']) < 0: raise InvalidInputError("Value out of range for: SNR") # Photon energy: main_pr_photon_energy if float(parameters.params['main_pr_photon_energy']) < 0: raise InvalidInputError("Value out of range for: Photon energy [keV]") # Pixel size: main_pr_pixel_size if float(parameters.params['main_pr_pixel_size']) < 0: raise InvalidInputError("Value out of range for: Pixel size [micron]") # Sample detector distance: main_pr_detector_distance if float(parameters.params['main_pr_detector_distance']) < 0: raise InvalidInputError("Value out of range for: Sample-detector distance [m]") # Delta/beta ratio: main_pr_delta_beta_ratio if int(parameters.params['main_pr_delta_beta_ratio']) < 0: raise InvalidInputError("Value out of range for: Delta/beta ratio: (try default if unsure)") # First row in projections: main_region_first_row if int(parameters.params['main_region_first_row']) < 0: raise InvalidInputError("Value out of range for: First row in projections") # Number of rows: main_region_number_rows if int(parameters.params['main_region_number_rows']) < 0: raise InvalidInputError("Value out of range for: Number of rows (ROI height)") # Reconstruct every Nth row: main_region_nth_row if int(parameters.params['main_region_nth_row']) < 0: raise InvalidInputError("Value out of range for: Reconstruct every Nth row") # Can be negative when 16-bit selected # Min value: main_region_histogram_min #if float(parameters.params['main_region_histogram_min']) < 0: # raise InvalidInputError("Value out of range for: Min value in 32-bit histogram") # Max value: main_region_histogram_max if float(parameters.params['main_region_histogram_max']) < 0: raise InvalidInputError("Value out of range for: Max value in 32-bit histogram") # x: main_region_crop_x if int(parameters.params['main_region_crop_x']) < 0: raise InvalidInputError("Value out of range for: Crop slices: x") # width: main_region_crop_width if int(parameters.params['main_region_crop_width']) < 0: raise InvalidInputError("Value out of range for: Crop slices: width") # y: main_region_crop_y if int(parameters.params['main_region_crop_y']) < 0: raise InvalidInputError("Value out of range for: Crop slices: y") # height: main_region_crop_height if int(parameters.params['main_region_crop_height']) < 0: raise InvalidInputError("Value out of range for: Crop slices: height") if int(parameters.params['advanced_ffc_eigen_pco_reps']) < 0: raise InvalidInputError("Value out of range for: Flat Field Correction: Eigen PCO Repetitions") if int(parameters.params['advanced_ffc_eigen_pco_downsample']) < 0: raise InvalidInputError("Value out of range for: Flat Field Correction: Eigen PCO Downsample") if int(parameters.params['advanced_ffc_downsample']) < 0: raise InvalidInputError("Value out of range for: Flat Field Correction: Downsample") # Can be negative value # Optional: rotate volume: main_region_rotate_volume_clock #if float(parameters.params['main_region_rotate_volume_clock']) < 0: # raise InvalidInputError("Value out of range for: Optional: rotate volume clock by [deg]") #TODO ADD CHECKING NLMDN SETTINGS #TODO ADD CHECKING FOR ADVANCED SETTINGS ''' if int(parameters.params['e_adv_rotation_range']) < 0: raise InvalidInputError("Advanced: Rotation range must be greater than or equal to zero") if float(parameters.params['advanced_advtofu_lamino_angle']) < 0 or float(parameters.params['advanced_advtofu_lamino_angle']) > 90: raise InvalidInputError("Advanced: Lamino angle must be a float between 0 and 90") if float(parameters.params['advanced_optimize_slice_mem_coeff']) < 0 or float(parameters.params['advanced_optimize_slice_mem_coeff']) > 1: raise InvalidInputError("Advanced: Slice memory coefficient must be between 0 and 1") ''' def get_fdt_names(self): DIRTYP = [] for i in self.e_DIRTYP: DIRTYP.append(i) LOG.debug("Result of get_fdt_names") LOG.debug(DIRTYP) return DIRTYP class tk_args(): def __init__(self, main_config_input_dir, main_config_temp_dir, main_config_output_dir, main_config_save_multipage_tiff, main_cor_axis_search_method, main_cor_axis_search_interval, main_cor_search_row_start, main_cor_recon_patch_size, main_cor_axis_column, main_cor_axis_increment_step, main_filters_remove_spots, main_filters_remove_spots_threshold, main_filters_remove_spots_blur_sigma, main_filters_ring_removal, main_filters_ring_removal_ufo_lpf, main_filters_ring_removal_ufo_lpf_1d_or_2d, main_filters_ring_removal_ufo_lpf_sigma_horizontal, main_filters_ring_removal_ufo_lpf_sigma_vertical, main_filters_ring_removal_sarepy_window_size, main_filters_ring_removal_sarepy_wide, main_filters_ring_removal_sarepy_window, main_filters_ring_removal_sarepy_SNR, main_pr_phase_retrieval, main_pr_photon_energy, main_pr_pixel_size, main_pr_detector_distance, main_pr_delta_beta_ratio, main_region_select_rows, main_region_first_row, main_region_number_rows, main_region_nth_row, main_region_clip_histogram, main_region_bit_depth, main_region_histogram_min, main_region_histogram_max, main_config_preprocess, main_config_preprocess_command, main_region_rotate_volume_clock, main_region_crop_slices, main_region_crop_x, main_region_crop_width, main_region_crop_y, main_region_crop_height, main_config_dry_run, main_config_save_params, main_config_keep_temp, advanced_ffc_sinFFC, advanced_ffc_method, advanced_ffc_eigen_pco_reps, advanced_ffc_eigen_pco_downsample, advanced_ffc_downsample, main_config_common_flats_darks, main_config_darks_path, main_config_flats_path, main_config_flats2_checkbox, main_config_flats2_path, advanced_nlmdn_apply_after_reco, advanced_nlmdn_input_dir, advanced_nlmdn_input_is_file, advanced_nlmdn_output_dir, advanced_nlmdn_save_bigtiff, advanced_nlmdn_sim_search_radius, advanced_nlmdn_patch_radius, advanced_nlmdn_smoothing_control, advanced_nlmdn_noise_std, advanced_nlmdn_window, advanced_nlmdn_fast, advanced_nlmdn_estimate_sigma, advanced_nlmdn_dry_run, advanced_advtofu_extended_settings, advanced_advtofu_lamino_angle, advanced_adv_tofu_z_axis_rotation, advanced_advtofu_center_position_z, advanced_advtofu_y_axis_rotation, advanced_advtofu_aux_ffc_dark_scale, advanced_advtofu_aux_ffc_flat_scale, advanced_optimize_verbose_console, advanced_optimize_slice_mem_coeff, advanced_optimize_num_gpus, advanced_optimize_slices_per_device): self.args={} # PATHS self.args['main_config_input_dir']=str(main_config_input_dir) setattr(self,'main_config_input_dir',self.args['main_config_input_dir']) self.args['main_config_output_dir']=str(main_config_output_dir) setattr(self,'main_config_output_dir',self.args['main_config_output_dir']) self.args['main_config_temp_dir']=str(main_config_temp_dir) setattr(self,'main_config_temp_dir',self.args['main_config_temp_dir']) self.args['main_config_save_multipage_tiff']=bool(main_config_save_multipage_tiff) setattr(self,'main_config_save_multipage_tiff',self.args['main_config_save_multipage_tiff']) # center of rotation parameters self.args['main_cor_axis_search_method']=int(main_cor_axis_search_method) setattr(self,'main_cor_axis_search_method',self.args['main_cor_axis_search_method']) self.args['main_cor_axis_search_interval']=str(main_cor_axis_search_interval) setattr(self,'main_cor_axis_search_interval',self.args['main_cor_axis_search_interval']) self.args['main_cor_recon_patch_size']=int(main_cor_recon_patch_size) setattr(self,'main_cor_recon_patch_size',self.args['main_cor_recon_patch_size']) self.args['main_cor_search_row_start']=int(main_cor_search_row_start) setattr(self,'main_cor_search_row_start',self.args['main_cor_search_row_start']) self.args['main_cor_axis_column']=float(main_cor_axis_column) setattr(self,'main_cor_axis_column',self.args['main_cor_axis_column']) self.args['main_cor_axis_increment_step']=float(main_cor_axis_increment_step) setattr(self,'main_cor_axis_increment_step',self.args['main_cor_axis_increment_step']) #ring removal self.args['main_filters_remove_spots']=bool(main_filters_remove_spots) setattr(self,'main_filters_remove_spots',self.args['main_filters_remove_spots']) self.args['main_filters_remove_spots_threshold']=int(main_filters_remove_spots_threshold) setattr(self,'main_filters_remove_spots_threshold', self.args['main_filters_remove_spots_threshold']) self.args['main_filters_remove_spots_blur_sigma']=int(main_filters_remove_spots_blur_sigma) setattr(self,'main_filters_remove_spots_blur_sigma',self.args['main_filters_remove_spots_blur_sigma']) self.args['main_filters_ring_removal']=bool(main_filters_ring_removal) setattr(self,'main_filters_ring_removal',self.args['main_filters_ring_removal']) self.args['main_filters_ring_removal_ufo_lpf'] = bool(main_filters_ring_removal_ufo_lpf) setattr(self, 'main_filters_ring_removal_ufo_lpf', self.args['main_filters_ring_removal_ufo_lpf']) self.args['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = bool(main_filters_ring_removal_ufo_lpf_1d_or_2d) setattr(self, 'main_filters_ring_removal_ufo_lpf_1d_or_2d', self.args['main_filters_ring_removal_ufo_lpf_1d_or_2d']) self.args['main_filters_ring_removal_ufo_lpf_sigma_horizontal'] = int(main_filters_ring_removal_ufo_lpf_sigma_horizontal) setattr(self,'main_filters_ring_removal_ufo_lpf_sigma_horizontal',self.args['main_filters_ring_removal_ufo_lpf_sigma_horizontal']) self.args['main_filters_ring_removal_ufo_lpf_sigma_vertical'] = int(main_filters_ring_removal_ufo_lpf_sigma_vertical) setattr(self, 'main_filters_ring_removal_ufo_lpf_sigma_vertical', self.args['main_filters_ring_removal_ufo_lpf_sigma_vertical']) self.args['main_filters_ring_removal_sarepy_window_size'] = int(main_filters_ring_removal_sarepy_window_size) setattr(self, 'main_filters_ring_removal_sarepy_window_size', self.args['main_filters_ring_removal_sarepy_window_size']) self.args['main_filters_ring_removal_sarepy_wide'] = bool(main_filters_ring_removal_sarepy_wide) setattr(self, 'main_filters_ring_removal_sarepy_wide', self.args['main_filters_ring_removal_sarepy_wide']) self.args['main_filters_ring_removal_sarepy_window'] = int(main_filters_ring_removal_sarepy_window) setattr(self, 'main_filters_ring_removal_sarepy_window', self.args['main_filters_ring_removal_sarepy_window']) self.args['main_filters_ring_removal_sarepy_SNR'] = int(main_filters_ring_removal_sarepy_SNR) setattr(self, 'main_filters_ring_removal_sarepy_SNR', self.args['main_filters_ring_removal_sarepy_SNR']) # phase retrieval self.args['main_pr_phase_retrieval'] = bool(main_pr_phase_retrieval) setattr(self, 'main_pr_phase_retrieval', self.args['main_pr_phase_retrieval']) self.args['main_pr_photon_energy']=float(main_pr_photon_energy) setattr(self,'main_pr_photon_energy',self.args['main_pr_photon_energy']) self.args['main_pr_pixel_size']=float(main_pr_pixel_size)*1e-6 setattr(self,'main_pr_pixel_size',self.args['main_pr_pixel_size']) self.args['main_pr_detector_distance']=float(main_pr_detector_distance) setattr(self,'main_pr_detector_distance',self.args['main_pr_detector_distance']) self.args['main_pr_delta_beta_ratio']=np.log10(float(main_pr_delta_beta_ratio)) setattr(self,'main_pr_delta_beta_ratio',self.args['main_pr_delta_beta_ratio']) # Crop vertically self.args['main_region_select_rows']=bool(main_region_select_rows) setattr(self,'main_region_select_rows',self.args['main_region_select_rows']) self.args['main_region_first_row']=int(main_region_first_row) setattr(self,'main_region_first_row',self.args['main_region_first_row']) self.args['main_region_number_rows']=int(main_region_number_rows) setattr(self,'main_region_number_rows',self.args['main_region_number_rows']) self.args['main_region_nth_row']=int(main_region_nth_row) setattr(self,'main_region_nth_row',self.args['main_region_nth_row']) # conv to 8 bit self.args['main_region_clip_histogram']=bool(main_region_clip_histogram) setattr(self,'main_region_clip_histogram',self.args['main_region_clip_histogram']) self.args['main_region_bit_depth']=int(main_region_bit_depth) setattr(self,'main_region_bit_depth',self.args['main_region_bit_depth']) self.args['main_region_histogram_min']=float(main_region_histogram_min) setattr(self,'main_region_histogram_min',self.args['main_region_histogram_min']) self.args['main_region_histogram_max']=float(main_region_histogram_max) setattr(self,'main_region_histogram_max',self.args['main_region_histogram_max']) # preprocessing attributes self.args['main_config_preprocess']=bool(main_config_preprocess) setattr(self,'main_config_preprocess',self.args['main_config_preprocess']) self.args['main_config_preprocess_command']=main_config_preprocess_command setattr(self,'main_config_preprocess_command',self.args['main_config_preprocess_command']) # ROI in slice self.args['main_region_crop_slices']=bool(main_region_crop_slices) setattr(self,'main_region_crop_slices',self.args['main_region_crop_slices']) self.args['main_region_crop_x']=int(main_region_crop_x) setattr(self,'main_region_crop_x',self.args['main_region_crop_x']) self.args['main_region_crop_width']=int(main_region_crop_width) setattr(self,'main_region_crop_width',self.args['main_region_crop_width']) self.args['main_region_crop_y']=int(main_region_crop_y) setattr(self,'main_region_crop_y',self.args['main_region_crop_y']) self.args['main_region_crop_height']=int(main_region_crop_height) setattr(self,'main_region_crop_height',self.args['main_region_crop_height']) # Optional FBP params self.args['main_region_rotate_volume_clock']= float(main_region_rotate_volume_clock) setattr(self,'main_region_rotate_volume_clock',self.args['main_region_rotate_volume_clock']) # misc settings self.args['main_config_dry_run']=bool(main_config_dry_run) setattr(self,'main_config_dry_run',self.args['main_config_dry_run']) self.args['main_config_save_params']=bool(main_config_save_params) setattr(self,'main_config_save_params',self.args['main_config_save_params']) self.args['main_config_keep_temp']=bool(main_config_keep_temp) setattr(self,'main_config_keep_temp',self.args['main_config_keep_temp']) #sinFFC settings self.args['advanced_ffc_sinFFC']=bool(advanced_ffc_sinFFC) setattr(self,'advanced_ffc_sinFFC', self.args['advanced_ffc_sinFFC']) self.args['advanced_ffc_method'] = str(advanced_ffc_method) setattr(self, 'advanced_ffc_method', self.args['advanced_ffc_method']) self.args['advanced_ffc_eigen_pco_reps']=int(advanced_ffc_eigen_pco_reps) setattr(self, 'advanced_ffc_eigen_pco_reps', self.args['advanced_ffc_eigen_pco_reps']) self.args['advanced_ffc_eigen_pco_downsample'] = int(advanced_ffc_eigen_pco_downsample) setattr(self, 'advanced_ffc_eigen_pco_downsample', self.args['advanced_ffc_eigen_pco_downsample']) self.args['advanced_ffc_downsample'] = int(advanced_ffc_downsample) setattr(self, 'advanced_ffc_downsample', self.args['advanced_ffc_downsample']) #Settings for using flats/darks across multiple experiments self.args['main_config_common_flats_darks'] = bool(main_config_common_flats_darks) setattr(self, 'main_config_common_flats_darks', self.args['main_config_common_flats_darks']) self.args['main_config_darks_path'] = str(main_config_darks_path) setattr(self, 'main_config_darks_path', self.args['main_config_darks_path']) self.args['main_config_flats_path'] = str(main_config_flats_path) setattr(self, 'main_config_flats_path', self.args['main_config_flats_path']) self.args['main_config_flats2_checkbox'] = bool(main_config_flats2_checkbox) setattr(self, 'main_config_flats2_checkbox', self.args['main_config_flats2_checkbox']) self.args['main_config_flats2_path'] = str(main_config_flats2_path) setattr(self, 'main_config_flats2_path', self.args['main_config_flats2_path']) #NLMDN Settings self.args['advanced_nlmdn_apply_after_reco'] = bool(advanced_nlmdn_apply_after_reco) setattr(self, 'advanced_nlmdn_apply_after_reco', self.args['advanced_nlmdn_apply_after_reco']) self.args['advanced_nlmdn_input_dir'] = str(advanced_nlmdn_input_dir) setattr(self, 'advanced_nlmdn_input_dir', self.args['advanced_nlmdn_input_dir']) self.args['advanced_nlmdn_input_is_file'] = bool(advanced_nlmdn_input_is_file) setattr(self, 'advanced_nlmdn_input_is_file', self.args['advanced_nlmdn_input_is_file']) self.args['advanced_nlmdn_output_dir'] = str(advanced_nlmdn_output_dir) setattr(self, 'advanced_nlmdn_output_dir', self.args['advanced_nlmdn_output_dir']) self.args['advanced_nlmdn_save_bigtiff'] = bool(advanced_nlmdn_save_bigtiff) setattr(self, 'advanced_nlmdn_save_bigtiff', self.args['advanced_nlmdn_save_bigtiff']) self.args['advanced_nlmdn_sim_search_radius'] = str(advanced_nlmdn_sim_search_radius) setattr(self, 'advanced_nlmdn_sim_search_radius', self.args['advanced_nlmdn_sim_search_radius']) self.args['advanced_nlmdn_patch_radius'] = str(advanced_nlmdn_patch_radius) setattr(self, 'advanced_nlmdn_patch_radius', self.args['advanced_nlmdn_patch_radius']) self.args['advanced_nlmdn_smoothing_control'] = str(advanced_nlmdn_smoothing_control) setattr(self, 'advanced_nlmdn_smoothing_control', self.args['advanced_nlmdn_smoothing_control']) self.args['advanced_nlmdn_noise_std'] = str(advanced_nlmdn_noise_std) setattr(self, 'advanced_nlmdn_noise_std', self.args['advanced_nlmdn_noise_std']) self.args['advanced_nlmdn_window'] = str(advanced_nlmdn_window) setattr(self, 'advanced_nlmdn_window', self.args['advanced_nlmdn_window']) self.args['advanced_nlmdn_fast'] = bool(advanced_nlmdn_fast) setattr(self, 'advanced_nlmdn_fast', self.args['advanced_nlmdn_fast']) self.args['advanced_nlmdn_estimate_sigma'] = bool(advanced_nlmdn_estimate_sigma) setattr(self, 'advanced_nlmdn_estimate_sigma', self.args['advanced_nlmdn_estimate_sigma']) self.args['advanced_nlmdn_dry_run'] = bool(advanced_nlmdn_dry_run) setattr(self, 'advanced_nlmdn_dry_run', self.args['advanced_nlmdn_dry_run']) #Advanced Settings self.args['advanced_advtofu_extended_settings'] = bool(advanced_advtofu_extended_settings) setattr(self, 'advanced_advtofu_extended_settings', self.args['advanced_advtofu_extended_settings']) self.args['advanced_advtofu_lamino_angle'] = str(advanced_advtofu_lamino_angle) setattr(self, 'advanced_advtofu_lamino_angle', self.args['advanced_advtofu_lamino_angle']) self.args['advanced_adv_tofu_z_axis_rotation'] = str(advanced_adv_tofu_z_axis_rotation) setattr(self, 'advanced_adv_tofu_z_axis_rotation', self.args['advanced_adv_tofu_z_axis_rotation']) self.args['advanced_advtofu_center_position_z'] = str(advanced_advtofu_center_position_z) setattr(self, 'advanced_advtofu_center_position_z', self.args['advanced_advtofu_center_position_z']) self.args['advanced_advtofu_y_axis_rotation'] = str(advanced_advtofu_y_axis_rotation) setattr(self, 'advanced_advtofu_y_axis_rotation', self.args['advanced_advtofu_y_axis_rotation']) self.args['advanced_advtofu_aux_ffc_dark_scale'] = str(advanced_advtofu_aux_ffc_dark_scale) setattr(self, 'advanced_advtofu_aux_ffc_dark_scale', self.args['advanced_advtofu_aux_ffc_dark_scale']) self.args['advanced_advtofu_aux_ffc_flat_scale'] = str(advanced_advtofu_aux_ffc_flat_scale) setattr(self, 'advanced_advtofu_aux_ffc_flat_scale', self.args['advanced_advtofu_aux_ffc_flat_scale']) #Optimization self.args['advanced_optimize_verbose_console'] = bool(advanced_optimize_verbose_console) setattr(self, 'advanced_optimize_verbose_console', self.args['advanced_optimize_verbose_console']) self.args['advanced_optimize_slice_mem_coeff'] = str(advanced_optimize_slice_mem_coeff) setattr(self, 'advanced_optimize_slice_mem_coeff', self.args['advanced_optimize_slice_mem_coeff']) self.args['advanced_optimize_num_gpus'] = str(advanced_optimize_num_gpus) setattr(self, 'advanced_optimize_num_gpus', self.args['advanced_optimize_num_gpus']) self.args['advanced_optimize_slices_per_device'] = str(advanced_optimize_slices_per_device) setattr(self, 'advanced_optimize_slices_per_device', self.args['advanced_optimize_slices_per_device']) LOG.debug("Contents of arg dict: ") LOG.debug(self.args.items()) class InvalidInputError(Exception): """ Error to be raised when input values from GUI are out of range or invalid """ tofu-0.12.0/tofu/ez/GUI/Main/filters.py000066400000000000000000000311771423713721100175060ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import ( QButtonGroup, QGridLayout, QLabel, QRadioButton, QCheckBox, QGroupBox, QLineEdit, ) from PyQt5.QtCore import Qt import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class FiltersGroup(QGroupBox): """ Filter settings """ def __init__(self): super().__init__() self.setTitle("Filters") self.setStyleSheet("QGroupBox {color: orange;}") self.remove_spots_checkBox = QCheckBox() self.remove_spots_checkBox.setText("Remove large spots from projections") self.remove_spots_checkBox.setToolTip( "Efficiently suppresses very intense rings \n stemming from defects in scintillator" ) self.remove_spots_checkBox.stateChanged.connect(self.set_remove_spots) self.threshold_label = QLabel() self.threshold_label.setText("Threshold (prominence of the spot) [counts]") self.threshold_label.setToolTip( "Outliers which will be considered as the part of the large spot" ) self.threshold_entry = QLineEdit() self.threshold_entry.editingFinished.connect(self.set_threshold) self.spot_blur_label = QLabel() self.spot_blur_label.setText("Spot blur. sigma [pixels]") self.spot_blur_label.setToolTip( "Regulates extent of the masked region around the detected outlier" ) self.spot_blur_entry = QLineEdit() self.spot_blur_entry.editingFinished.connect(self.set_spot_blur) self.enable_RR_checkbox = QCheckBox() self.enable_RR_checkbox.setText("Enable ring removal") self.remove_spots_checkBox.setToolTip( "To suppress ring artifacts" " stemming from intensity fluctuations and detector nonlinearities" ) self.enable_RR_checkbox.stateChanged.connect(self.set_ring_removal) self.use_LPF_rButton = QRadioButton() self.use_LPF_rButton.setText("Use ufo Fourier-transform based filter") self.use_LPF_rButton.clicked.connect(self.select_rButton) self.sarepy_rButton = QRadioButton() self.sarepy_rButton.setText("Use sarepy sorting: ") self.sarepy_rButton.clicked.connect(self.select_rButton) self.sarepy_rButton.setToolTip( "Non-FFT based algorithms from \n /Nghia T. Vo et al, Opt. Express 26, 28396 (2018)" ) self.filter_rButton_group = QButtonGroup(self) self.filter_rButton_group.addButton(self.use_LPF_rButton) self.filter_rButton_group.addButton(self.sarepy_rButton) self.one_dimens_rButton = QRadioButton() self.one_dimens_rButton.setText("1D") self.one_dimens_rButton.clicked.connect(self.select_dimens_rButton) self.one_dimens_rButton.setToolTip("Only low-pass filter along the lines of sinogram") self.two_dimens_rButton = QRadioButton() self.two_dimens_rButton.setText("2D") self.two_dimens_rButton.clicked.connect(self.select_dimens_rButton) self.two_dimens_rButton.setToolTip( "Low-pass filter along the lines and high-pass filter along the columns" ) self.dimens_rButton_group = QButtonGroup(self) self.dimens_rButton_group.addButton(self.one_dimens_rButton) self.dimens_rButton_group.addButton(self.two_dimens_rButton) self.sigma_horizontal_label = QLabel() self.sigma_horizontal_label.setText("sigma horizontal") self.sigma_horizontal_label.setToolTip( "Width [pixels] of Gaussian-shaped low-pass filter in frequency domain" ) self.sigma_horizontal_entry = QLineEdit() self.sigma_horizontal_entry.editingFinished.connect(self.set_sigma_horizontal) self.sigma_vertical_label = QLabel() self.sigma_vertical_label.setText("sigma vertical") self.sigma_vertical_label.setToolTip( "Width [pixels] of Gaussian-shaped high-pass filter in frequency domain" ) self.sigma_vertical_entry = QLineEdit() self.sigma_vertical_entry.editingFinished.connect(self.set_sigma_vertical) self.wind_size_label = QLabel() self.wind_size_label.setText("window size") self.wind_size_label.setToolTip("Window size in remove_stripe_based_sorting algorithm") self.wind_size_entry = QLineEdit() self.wind_size_entry.editingFinished.connect(self.set_window_size) self.wind_size_entry.setToolTip("Typically in the range 31..51 ") self.remove_wide_checkbox = QCheckBox() self.remove_wide_checkbox.setText("Remove wide") self.remove_wide_checkbox.setToolTip("Window size in remove_large_stripe algorithm") self.remove_wide_checkbox.stateChanged.connect(self.set_remove_wide) self.remove_wide_label = QLabel() self.remove_wide_label.setText("window") self.remove_wide_label.setToolTip("Typically in the range 51..131 ") self.remove_wide_entry = QLineEdit() self.remove_wide_entry.editingFinished.connect(self.set_wind) self.SNR_label = QLabel() self.SNR_label.setText("SNR") self.SNR_label.setToolTip("SNR param in remove_large_stripe algorithm") self.SNR_entry = QLineEdit() self.SNR_entry.editingFinished.connect(self.set_SNR) self.set_layout() def set_layout(self): layout = QGridLayout() remove_spots_groupbox = QGroupBox() remove_spots_layout = QGridLayout() remove_spots_layout.addWidget(self.remove_spots_checkBox, 0, 0) remove_spots_layout.addWidget(self.threshold_label, 1, 0) remove_spots_layout.addWidget(self.threshold_entry, 1, 1, 1, 7) remove_spots_layout.addWidget(self.spot_blur_label, 2, 0) remove_spots_layout.addWidget(self.spot_blur_entry, 2, 1, 1, 7) remove_spots_groupbox.setLayout(remove_spots_layout) layout.addWidget(remove_spots_groupbox) rr_groupbox = QGroupBox() rr_layout = QGridLayout() rr_layout.addWidget(self.enable_RR_checkbox, 3, 0) rr_layout.addWidget(self.use_LPF_rButton, 4, 0) rr_layout.addWidget(self.one_dimens_rButton, 4, 1) rr_layout.addWidget(self.two_dimens_rButton, 4, 2) rr_layout.addWidget(self.sigma_horizontal_label, 4, 3, Qt.AlignRight) rr_layout.addWidget(self.sigma_horizontal_entry, 4, 4) rr_layout.addWidget(self.sigma_vertical_label, 4, 5, Qt.AlignRight) rr_layout.addWidget(self.sigma_vertical_entry, 4, 6) rr_layout.addWidget(self.sarepy_rButton, 5, 0) rr_layout.addWidget(self.wind_size_label, 5, 1) rr_layout.addWidget(self.wind_size_entry, 5, 2) rr_layout.addWidget(self.remove_wide_checkbox, 5, 3) rr_layout.addWidget(self.remove_wide_label, 5, 4, Qt.AlignRight) rr_layout.addWidget(self.remove_wide_entry, 5, 5) rr_layout.addWidget(self.SNR_label, 5, 6) rr_layout.addWidget(self.SNR_entry, 5, 7) rr_groupbox.setLayout(rr_layout) layout.addWidget(rr_groupbox, 3, 0) self.setLayout(layout) def init_values(self): self.remove_wide_checkbox.setChecked(False) self.set_remove_spots() parameters.params['main_filters_remove_spots'] = False self.threshold_entry.setText( str(parameters.params['main_filters_remove_spots_threshold']) ) self.spot_blur_entry.setText( str(parameters.params['main_filters_remove_spots_blur_sigma']) ) self.enable_RR_checkbox.setChecked(False) self.set_ring_removal() parameters.params['main_filters_ring_removal'] = False self.use_LPF_rButton.setChecked(True) self.select_rButton() self.sarepy_rButton.setChecked(False) self.two_dimens_rButton.setChecked(True) parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = False self.sigma_horizontal_entry.setText( str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal']) ) self.sigma_vertical_entry.setText( str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical']) ) self.wind_size_entry.setText("21") self.remove_wide_checkbox.setChecked(False) parameters.params['main_filters_ring_removal_sarepy_wide'] = False self.remove_wide_entry.setText("91") self.SNR_entry.setText("3") def set_values_from_params(self): self.remove_spots_checkBox.setChecked(parameters.params['main_filters_remove_spots']) self.threshold_entry.setText(str(parameters.params['main_filters_remove_spots_threshold'])) self.spot_blur_entry.setText(str(parameters.params['main_filters_remove_spots_blur_sigma'])) self.enable_RR_checkbox.setChecked(parameters.params['main_filters_ring_removal']) if parameters.params['main_filters_ring_removal_ufo_lpf'] == True: self.use_LPF_rButton.setChecked(True) elif parameters.params['main_filters_ring_removal_ufo_lpf'] == False: self.use_LPF_rButton.setChecked(False) if parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] == True: self.one_dimens_rButton.setChecked(True) self.two_dimens_rButton.setChecked(False) elif parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] == False: self.two_dimens_rButton.setChecked(True) self.one_dimens_rButton.setChecked(False) self.sigma_horizontal_entry.setText(str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'])) self.sigma_vertical_entry.setText(str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical'])) self.wind_size_entry.setText(str(parameters.params['main_filters_ring_removal_sarepy_window_size'])) self.remove_wide_checkbox.setChecked(parameters.params['main_filters_ring_removal_sarepy_wide']) self.remove_wide_entry.setText(str(parameters.params['main_filters_ring_removal_sarepy_window'])) self.SNR_entry.setText(str(parameters.params['main_filters_ring_removal_sarepy_SNR'])) def set_remove_spots(self): LOG.debug("Remove large spots:" + str(self.remove_spots_checkBox.isChecked())) parameters.params['main_filters_remove_spots'] = bool(self.remove_spots_checkBox.isChecked()) def set_threshold(self): LOG.debug(self.threshold_entry.text()) parameters.params['main_filters_remove_spots_threshold'] = str(self.threshold_entry.text()) def set_spot_blur(self): LOG.debug(self.spot_blur_entry.text()) parameters.params['main_filters_remove_spots_blur_sigma'] = str(self.spot_blur_entry.text()) def set_ring_removal(self): LOG.debug("RR: " + str(self.enable_RR_checkbox.isChecked())) parameters.params['main_filters_ring_removal'] = bool(self.enable_RR_checkbox.isChecked()) def select_rButton(self): if self.use_LPF_rButton.isChecked(): LOG.debug("Use LPF") parameters.params['main_filters_ring_removal_ufo_lpf'] = bool(True) elif self.sarepy_rButton.isChecked(): LOG.debug("Use Sarepy") parameters.params['main_filters_ring_removal_ufo_lpf'] = bool(False) def select_dimens_rButton(self): if self.one_dimens_rButton.isChecked(): LOG.debug("One dimension") parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = bool(True) elif self.two_dimens_rButton.isChecked(): LOG.debug("Two dimensions") parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = bool(False) def set_sigma_horizontal(self): LOG.debug(self.sigma_horizontal_entry.text()) parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'] = \ str(self.sigma_horizontal_entry.text()) def set_sigma_vertical(self): LOG.debug(self.sigma_vertical_entry.text()) parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical'] = \ str(self.sigma_vertical_entry.text()) def set_window_size(self): LOG.debug(self.wind_size_entry.text()) parameters.params['main_filters_ring_removal_sarepy_window_size'] = str(self.wind_size_entry.text()) def set_remove_wide(self): LOG.debug("Wide: " + str(self.remove_wide_checkbox.isChecked())) parameters.params['main_filters_ring_removal_sarepy_wide'] = bool(self.remove_wide_checkbox.isChecked()) def set_wind(self): LOG.debug(self.remove_wide_entry.text()) parameters.params['main_filters_ring_removal_sarepy_window'] = str(self.remove_wide_entry.text()) def set_SNR(self): LOG.debug(self.SNR_entry.text()) parameters.params['main_filters_ring_removal_sarepy_SNR'] = str(self.SNR_entry.text())tofu-0.12.0/tofu/ez/GUI/Main/phase_retrieval.py000066400000000000000000000102311423713721100211770ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class PhaseRetrievalGroup(QGroupBox): """ Phase Retrieval settings """ def __init__(self): super().__init__() self.setTitle("Phase Retrieval") self.setStyleSheet("QGroupBox {color: blue;}") self.enable_PR_checkBox = QCheckBox() self.enable_PR_checkBox.setText("Enable Paganin/TIE phase retrieval") self.enable_PR_checkBox.stateChanged.connect(self.set_PR) self.photon_energy_label = QLabel() self.photon_energy_label.setText("Photon energy [keV]") self.photon_energy_entry = QLineEdit() self.photon_energy_entry.editingFinished.connect(self.set_photon_energy) self.photon_energy_entry.setStyleSheet("background-color:white") self.pixel_size_label = QLabel() self.pixel_size_label.setText("Pixel size [micron]") self.pixel_size_entry = QLineEdit() self.pixel_size_entry.editingFinished.connect(self.set_pixel_size) self.pixel_size_entry.setStyleSheet("background-color:white") self.detector_distance_label = QLabel() self.detector_distance_label.setText("Sample-detector distance [m]") self.detector_distance_entry = QLineEdit() self.detector_distance_entry.editingFinished.connect(self.set_detector_distance) self.detector_distance_entry.setStyleSheet("background-color:white") self.delta_beta_ratio_label = QLabel() self.delta_beta_ratio_label.setText("Delta/beta ratio: (try default if unsure)") self.delta_beta_ratio_entry = QLineEdit() self.delta_beta_ratio_entry.editingFinished.connect(self.set_delta_beta) self.delta_beta_ratio_entry.setStyleSheet("background-color:white") self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.enable_PR_checkBox, 0, 0) layout.addWidget(self.photon_energy_label, 1, 0) layout.addWidget(self.photon_energy_entry, 1, 1) layout.addWidget(self.pixel_size_label, 2, 0) layout.addWidget(self.pixel_size_entry, 2, 1) layout.addWidget(self.detector_distance_label, 3, 0) layout.addWidget(self.detector_distance_entry, 3, 1) layout.addWidget(self.delta_beta_ratio_label, 4, 0) layout.addWidget(self.delta_beta_ratio_entry, 4, 1) self.setLayout(layout) def init_values(self): self.enable_PR_checkBox.setChecked(False) parameters.params['main_pr_phase_retrieval'] = False self.photon_energy_entry.setText("20") self.pixel_size_entry.setText("3.6") self.detector_distance_entry.setText("0.1") self.delta_beta_ratio_entry.setText("200") def set_values_from_params(self): self.enable_PR_checkBox.setChecked(parameters.params['main_pr_phase_retrieval']) self.photon_energy_entry.setText(str(parameters.params['main_pr_photon_energy'])) self.pixel_size_entry.setText(str(parameters.params['main_pr_pixel_size'])) self.detector_distance_entry.setText(str(parameters.params['main_pr_detector_distance'])) self.delta_beta_ratio_entry.setText(str(parameters.params['main_pr_delta_beta_ratio'])) def set_PR(self): LOG.debug("PR: " + str(self.enable_PR_checkBox.isChecked())) parameters.params['main_pr_phase_retrieval'] = bool(self.enable_PR_checkBox.isChecked()) def set_photon_energy(self): LOG.debug(self.photon_energy_entry.text()) parameters.params['main_pr_photon_energy'] = str(self.photon_energy_entry.text()) def set_pixel_size(self): LOG.debug(self.pixel_size_entry.text()) parameters.params['main_pr_pixel_size'] = str(self.pixel_size_entry.text()) def set_detector_distance(self): LOG.debug(self.detector_distance_entry.text()) parameters.params['main_pr_detector_distance'] = str(self.detector_distance_entry.text()) def set_delta_beta(self): LOG.debug(self.delta_beta_ratio_entry.text()) parameters.params['main_pr_delta_beta_ratio'] = str(self.delta_beta_ratio_entry.text())tofu-0.12.0/tofu/ez/GUI/Main/region_and_histogram.py000066400000000000000000000245731423713721100222220ustar00rootroot00000000000000import logging from PyQt5.QtWidgets import QGridLayout, QRadioButton, QLabel, QGroupBox, QLineEdit, QCheckBox from PyQt5.QtCore import Qt import tofu.ez.params as parameters LOG = logging.getLogger(__name__) class ROIandHistGroup(QGroupBox): """ Binning settings """ def __init__(self): super().__init__() self.setTitle("Region of Interest and Histogram Settings") self.setStyleSheet("QGroupBox {color: red;}") self.select_rows_checkbox = QCheckBox() self.select_rows_checkbox.setText("Select rows which will be reconstructed") self.select_rows_checkbox.stateChanged.connect(self.set_select_rows) self.first_row_label = QLabel() self.first_row_label.setText("First row in projections") self.first_row_label.setToolTip("Counting from the top") self.first_row_entry = QLineEdit() self.first_row_entry.editingFinished.connect(self.set_first_row) self.num_rows_label = QLabel() self.num_rows_label.setText("Number of rows (ROI height)") self.num_rows_entry = QLineEdit() self.num_rows_entry.editingFinished.connect(self.set_num_rows) self.nth_row_label = QLabel() self.nth_row_label.setText("Step (reconstruct every Nth row)") self.nth_row_entry = QLineEdit() self.nth_row_entry.editingFinished.connect(self.set_reco_nth_rows) self.clip_histo_checkbox = QCheckBox() self.clip_histo_checkbox.setText("Clip histogram and save slices in") self.clip_histo_checkbox.stateChanged.connect(self.set_clip_histo) self.eight_bit_rButton = QRadioButton() self.eight_bit_rButton.setText("8-bit") self.eight_bit_rButton.setChecked(True) self.eight_bit_rButton.clicked.connect(self.set_bitdepth) self.sixteen_bit_rButton = QRadioButton() self.sixteen_bit_rButton.setText("16-bit") self.sixteen_bit_rButton.clicked.connect(self.set_bitdepth) self.min_val_label = QLabel() self.min_val_label.setText("Min value in 32-bit histogram") self.min_val_entry = QLineEdit() self.min_val_entry.editingFinished.connect(self.set_min_val) self.max_val_label = QLabel() self.max_val_label.setText("Max value in 32-bit histogram") self.max_val_entry = QLineEdit() self.max_val_entry.editingFinished.connect(self.set_max_val) self.crop_slices_checkbox = QCheckBox() self.crop_slices_checkbox.setText("Crop slices") self.crop_slices_checkbox.setToolTip("Crop slices in the reconstruction plane \n" "(x,y) - top left corner of selection \n" "(width, height) - size of selection") self.crop_slices_checkbox.stateChanged.connect(self.set_crop_slices) self.x_val_label = QLabel() self.x_val_label.setText("x") self.x_val_label.setToolTip("First column (counting from left)") self.x_val_entry = QLineEdit() self.x_val_entry.editingFinished.connect(self.set_x) self.width_val_label = QLabel() self.width_val_label.setText("width") self.width_val_entry = QLineEdit() self.width_val_entry.editingFinished.connect(self.set_width) self.y_val_label = QLabel() self.y_val_label.setText("y") self.y_val_label.setToolTip("First row (counting from top)") self.y_val_entry = QLineEdit() self.y_val_entry.editingFinished.connect(self.set_y) self.height_val_label = QLabel() self.height_val_label.setText("height") self.height_val_entry = QLineEdit() self.height_val_entry.editingFinished.connect(self.set_height) self.rotate_vol_label = QLabel() self.rotate_vol_label.setText("Rotate volume clockwise by [deg]") self.rotate_vol_entry = QLineEdit() self.rotate_vol_entry.editingFinished.connect(self.set_rotate_volume) # self.setStyleSheet('background-color:Azure') self.set_layout() def set_layout(self): """ Sets the layout of buttons, labels, etc. for binning group """ layout = QGridLayout() layout.addWidget(self.select_rows_checkbox, 0, 0) layout.addWidget(self.first_row_label, 1, 0) layout.addWidget(self.first_row_entry, 1, 1, 1, 8) layout.addWidget(self.num_rows_label, 2, 0) layout.addWidget(self.num_rows_entry, 2, 1, 1, 8) layout.addWidget(self.nth_row_label, 3, 0) layout.addWidget(self.nth_row_entry, 3, 1, 1, 8) layout.addWidget(self.clip_histo_checkbox, 4, 0) layout.addWidget(self.eight_bit_rButton, 4, 1) layout.addWidget(self.sixteen_bit_rButton, 4, 2) layout.addWidget(self.min_val_label, 5, 0) layout.addWidget(self.min_val_entry, 5, 1, 1, 8) layout.addWidget(self.max_val_label, 6, 0) layout.addWidget(self.max_val_entry, 6, 1, 1, 8) layout.addWidget(self.crop_slices_checkbox, 7, 0) layout.addWidget(self.x_val_label, 7, 1)#, Qt.AlignRight) layout.addWidget(self.x_val_entry, 7, 2) layout.addWidget(self.width_val_label, 7, 3)#, Qt.AlignRight) layout.addWidget(self.width_val_entry, 7, 4) layout.addWidget(self.y_val_label, 7, 5) layout.addWidget(self.y_val_entry, 7, 6) layout.addWidget(self.height_val_label, 7, 7) layout.addWidget(self.height_val_entry, 7, 8) layout.addWidget(self.rotate_vol_label, 8, 0) layout.addWidget(self.rotate_vol_entry, 8, 1, 1, 8) self.setLayout(layout) def init_values(self): self.select_rows_checkbox.setChecked(False) parameters.params['main_region_select_rows'] = False self.first_row_entry.setText("100") self.num_rows_entry.setText("200") self.nth_row_entry.setText("20") self.clip_histo_checkbox.setChecked(False) parameters.params['main_region_clip_histogram'] = False self.eight_bit_rButton.setChecked(True) parameters.params['main_region_bit_depth'] = str(8) self.min_val_entry.setText("0.0") self.max_val_entry.setText("0.0") self.crop_slices_checkbox.setChecked(False) parameters.params['main_region_crop_slices'] = False self.x_val_entry.setText("0") self.width_val_entry.setText("0") self.y_val_entry.setText("0") self.height_val_entry.setText("0") self.rotate_vol_entry.setText("0.0") def set_values_from_params(self): self.select_rows_checkbox.setChecked(parameters.params['main_region_select_rows']) self.first_row_entry.setText(str(parameters.params['main_region_first_row'])) self.num_rows_entry.setText(str(parameters.params['main_region_number_rows'])) self.nth_row_entry.setText(str(parameters.params['main_region_nth_row'])) self.clip_histo_checkbox.setChecked(parameters.params['main_region_clip_histogram']) if int(parameters.params['main_region_bit_depth']) == 8: self.eight_bit_rButton.setChecked(True) self.sixteen_bit_rButton.setChecked(False) elif int(parameters.params['main_region_bit_depth']) == 16: self.eight_bit_rButton.setChecked(False) self.sixteen_bit_rButton.setChecked(True) self.min_val_entry.setText(str(parameters.params['main_region_histogram_min'])) self.max_val_entry.setText(str(parameters.params['main_region_histogram_max'])) self.crop_slices_checkbox.setChecked(parameters.params['main_region_crop_slices']) self.x_val_entry.setText(str(parameters.params['main_region_crop_x'])) self.width_val_entry.setText(str(parameters.params['main_region_crop_width'])) self.y_val_entry.setText(str(parameters.params['main_region_crop_y'])) self.height_val_entry.setText(str(parameters.params['main_region_crop_height'])) self.rotate_vol_entry.setText(str(parameters.params['main_region_rotate_volume_clock'])) def set_select_rows(self): LOG.debug("Select rows: " + str(self.select_rows_checkbox.isChecked())) parameters.params['main_region_select_rows'] = bool(self.select_rows_checkbox.isChecked()) def set_first_row(self): LOG.debug(self.first_row_entry.text()) parameters.params['main_region_first_row'] = str(self.first_row_entry.text()) def set_num_rows(self): LOG.debug(self.num_rows_entry.text()) parameters.params['main_region_number_rows'] = str(self.num_rows_entry.text()) def set_reco_nth_rows(self): LOG.debug(self.nth_row_entry.text()) parameters.params['main_region_nth_row'] = str(self.nth_row_entry.text()) def set_clip_histo(self): LOG.debug("Clip histo: " + str(self.clip_histo_checkbox.isChecked())) parameters.params['main_region_clip_histogram'] = bool(self.clip_histo_checkbox.isChecked()) def set_bitdepth(self): if self.eight_bit_rButton.isChecked(): LOG.debug("8 bit") parameters.params['main_region_bit_depth'] = str(8) elif self.sixteen_bit_rButton.isChecked(): LOG.debug("16 bit") parameters.params['main_region_bit_depth'] = str(16) def set_min_val(self): LOG.debug(self.min_val_entry.text()) parameters.params['main_region_histogram_min'] = str(self.min_val_entry.text()) def set_max_val(self): LOG.debug(self.max_val_entry.text()) parameters.params['main_region_histogram_max'] = str(self.max_val_entry.text()) def set_crop_slices(self): LOG.debug("Crop slices: " + str(self.crop_slices_checkbox.isChecked())) parameters.params['main_region_crop_slices'] = bool(self.crop_slices_checkbox.isChecked()) def set_x(self): LOG.debug(self.x_val_entry.text()) parameters.params['main_region_crop_x'] = str(self.x_val_entry.text()) def set_width(self): LOG.debug(self.width_val_entry.text()) parameters.params['main_region_crop_width'] = str(self.width_val_entry.text()) def set_y(self): LOG.debug(self.y_val_entry.text()) parameters.params['main_region_crop_y'] = str(self.y_val_entry.text()) def set_height(self): LOG.debug(self.height_val_entry.text()) parameters.params['main_region_crop_height'] = str(self.height_val_entry.text()) def set_rotate_volume(self): LOG.debug(self.rotate_vol_entry.text()) parameters.params["main_region_rotate_volume_clock"] = str(self.rotate_vol_entry.text()) tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/000077500000000000000000000000001423713721100200735ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/__init__.py000066400000000000000000000000001423713721100221720ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_multi_stitch_qt.py000066400000000000000000000516641423713721100247630ustar00rootroot00000000000000from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import pyqtSignal import logging from shutil import rmtree import os import getpass import yaml from tofu.ez.Helpers.stitch_funcs import main_360_mp_depth2 from tofu.ez.GUI.message_dialog import warning_message # Params import tofu.ez.params as params LOG = logging.getLogger(__name__) class MultiStitch360Group(QGroupBox): get_fdt_names_on_stitch_pressed = pyqtSignal() def __init__(self): super().__init__() self.setTitle("Batch horizontal stitching of half-acquistion mode data sets") self.setStyleSheet('QGroupBox {color: red;}') self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.temp_dir_button = QPushButton("Select temporary directory - default value recommended") self.temp_dir_button.clicked.connect(self.temp_button_pressed) self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton("Directory to save stitched images") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.crop_checkbox = QCheckBox("Crop all projections to match the width of smallest stitched projection") self.crop_checkbox.clicked.connect(self.set_crop_projections_checkbox) self.axis_bottom_label = QLabel() self.axis_bottom_label.setText("Axis of Rotation (Dir 00):") self.axis_bottom_entry = QLineEdit() self.axis_bottom_entry.editingFinished.connect(self.set_axis_bottom) self.axis_top_label = QLabel("Axis of Rotation (Dir 0N):") self.axis_group = QGroupBox("Enter axis of rotation manually") self.axis_group.clicked.connect(self.set_axis_group) self.axis_top_entry = QLineEdit() self.axis_top_entry.editingFinished.connect(self.set_axis_top) self.axis_z000_label = QLabel("Axis of Rotation (Dir 00):") self.axis_z000_entry = QLineEdit() self.axis_z000_entry.editingFinished.connect(self.set_z000) self.axis_z001_label = QLabel("Axis of Rotation (Dir 01):") self.axis_z001_entry = QLineEdit() self.axis_z001_entry.editingFinished.connect(self.set_z001) self.axis_z002_label = QLabel("Axis of Rotation (Dir 02):") self.axis_z002_entry = QLineEdit() self.axis_z002_entry.editingFinished.connect(self.set_z002) self.axis_z003_label = QLabel("Axis of Rotation (Dir 03):") self.axis_z003_entry = QLineEdit() self.axis_z003_entry.editingFinished.connect(self.set_z003) self.axis_z004_label = QLabel("Axis of Rotation (Dir 04):") self.axis_z004_entry = QLineEdit() self.axis_z004_entry.editingFinished.connect(self.set_z004) self.axis_z005_label = QLabel("Axis of Rotation (Dir 05):") self.axis_z005_entry = QLineEdit() self.axis_z005_entry.editingFinished.connect(self.set_z005) self.axis_z006_label = QLabel("Axis of Rotation (Dir 06):") self.axis_z006_entry = QLineEdit() self.axis_z006_entry.editingFinished.connect(self.set_z006) self.axis_z007_label = QLabel("Axis of Rotation (Dir 07):") self.axis_z007_entry = QLineEdit() self.axis_z007_entry.editingFinished.connect(self.set_z007) self.axis_z008_label = QLabel("Axis of Rotation (Dir 08):") self.axis_z008_entry = QLineEdit() self.axis_z008_entry.editingFinished.connect(self.set_z008) self.axis_z009_label = QLabel("Axis of Rotation (Dir 09):") self.axis_z009_entry = QLineEdit() self.axis_z009_entry.editingFinished.connect(self.set_z009) self.axis_z010_label = QLabel("Axis of Rotation (Dir 10):") self.axis_z010_entry = QLineEdit() self.axis_z010_entry.editingFinished.connect(self.set_z010) self.axis_z011_label = QLabel("Axis of Rotation (Dir 11):") self.axis_z011_entry = QLineEdit() self.axis_z011_entry.editingFinished.connect(self.set_z011) self.stitch_button = QPushButton("Stitch") self.stitch_button.clicked.connect(self.stitch_button_pressed) self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold") self.delete_button = QPushButton("Delete output dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 4) layout.addWidget(self.input_dir_entry, 1, 0, 1, 4) layout.addWidget(self.temp_dir_button, 2, 0, 1, 4) layout.addWidget(self.temp_dir_entry, 3, 0, 1, 4) layout.addWidget(self.output_dir_button, 4, 0, 1, 4) layout.addWidget(self.output_dir_entry, 5, 0, 1, 4) layout.addWidget(self.crop_checkbox, 6, 0, 1, 4) layout.addWidget(self.axis_bottom_label, 7, 0) layout.addWidget(self.axis_bottom_entry, 7, 1) layout.addWidget(self.axis_top_label, 7, 2) layout.addWidget(self.axis_top_entry, 7, 3) self.axis_group.setCheckable(True) self.axis_group.setChecked(False) axis_layout = QGridLayout() axis_layout.addWidget(self.axis_z000_label, 0, 0) axis_layout.addWidget(self.axis_z000_entry, 0, 1) axis_layout.addWidget(self.axis_z006_label, 0, 2) axis_layout.addWidget(self.axis_z006_entry, 0, 3) axis_layout.addWidget(self.axis_z001_label, 1, 0) axis_layout.addWidget(self.axis_z001_entry, 1, 1) axis_layout.addWidget(self.axis_z007_label, 1, 2) axis_layout.addWidget(self.axis_z007_entry, 1, 3) axis_layout.addWidget(self.axis_z002_label, 2, 0) axis_layout.addWidget(self.axis_z002_entry, 2, 1) axis_layout.addWidget(self.axis_z008_label, 2, 2) axis_layout.addWidget(self.axis_z008_entry, 2, 3) axis_layout.addWidget(self.axis_z003_label, 3, 0) axis_layout.addWidget(self.axis_z003_entry, 3, 1) axis_layout.addWidget(self.axis_z009_label, 3, 2) axis_layout.addWidget(self.axis_z009_entry, 3, 3) axis_layout.addWidget(self.axis_z004_label, 4, 0) axis_layout.addWidget(self.axis_z004_entry, 4, 1) axis_layout.addWidget(self.axis_z010_label, 4, 2) axis_layout.addWidget(self.axis_z010_entry, 4, 3) axis_layout.addWidget(self.axis_z005_label, 5, 0) axis_layout.addWidget(self.axis_z005_entry, 5, 1) axis_layout.addWidget(self.axis_z011_label, 5, 2) axis_layout.addWidget(self.axis_z011_entry, 5, 3) self.axis_group.setLayout(axis_layout) self.axis_group.setTabOrder(self.axis_z000_entry, self.axis_z001_entry) self.axis_group.setTabOrder(self.axis_z001_entry, self.axis_z002_entry) self.axis_group.setTabOrder(self.axis_z002_entry, self.axis_z003_entry) self.axis_group.setTabOrder(self.axis_z003_entry, self.axis_z004_entry) self.axis_group.setTabOrder(self.axis_z004_entry, self.axis_z005_entry) self.axis_group.setTabOrder(self.axis_z005_entry, self.axis_z006_entry) self.axis_group.setTabOrder(self.axis_z006_entry, self.axis_z007_entry) self.axis_group.setTabOrder(self.axis_z007_entry, self.axis_z008_entry) self.axis_group.setTabOrder(self.axis_z008_entry, self.axis_z009_entry) self.axis_group.setTabOrder(self.axis_z009_entry, self.axis_z010_entry) self.axis_group.setTabOrder(self.axis_z010_entry, self.axis_z011_entry) layout.addWidget(self.axis_group, 8, 0, 1, 4) layout.addWidget(self.help_button, 9, 0) layout.addWidget(self.delete_button, 9, 1) layout.addWidget(self.stitch_button, 9, 2, 1, 2) layout.addWidget(self.import_parameters_button, 10, 0, 1, 2) layout.addWidget(self.save_parameters_button, 10, 2, 1, 2) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': '360_multi_stitch'} self.parameters['360multi_input_dir'] = os.path.expanduser('~')#"~/"#os.getcwd() self.input_dir_entry.setText(self.parameters['360multi_input_dir']) self.parameters['360multi_temp_dir'] = os.path.join( os.path.expanduser('~'), "tmp-batch360stitch") self.temp_dir_entry.setText(self.parameters['360multi_temp_dir']) self.parameters['360multi_output_dir'] = os.path.expanduser('~') self.output_dir_entry.setText(self.parameters['360multi_output_dir']) self.parameters['360multi_crop_projections'] = True self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections']) self.parameters['360multi_bottom_axis'] = 245 self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis'])) self.parameters['360multi_top_axis'] = 245 self.axis_top_entry.setText(str(self.parameters['360multi_top_axis'])) self.parameters['360multi_axis'] = self.parameters['360multi_bottom_axis'] self.parameters['360multi_manual_axis'] = False self.parameters['360multi_axis_dict'] = dict.fromkeys(['z000', 'z001', 'z002', 'z003', 'z004', 'z005', 'z006', 'z007', 'z008', 'z009', 'z010', 'z011'], 200) def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != '360_multi_stitch': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['360multi_input_dir']) self.temp_dir_entry.setText(self.parameters['360multi_temp_dir']) self.output_dir_entry.setText(self.parameters['360multi_output_dir']) self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections']) self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis'])) self.axis_top_entry.setText(str(self.parameters['360multi_top_axis'])) self.axis_group.setChecked(bool(self.parameters['360multi_manual_axis'])) self.axis_z000_entry.setText(str(self.parameters['360multi_axis_dict']['z000'])) self.axis_z001_entry.setText(str(self.parameters['360multi_axis_dict']['z001'])) self.axis_z002_entry.setText(str(self.parameters['360multi_axis_dict']['z002'])) self.axis_z003_entry.setText(str(self.parameters['360multi_axis_dict']['z003'])) self.axis_z004_entry.setText(str(self.parameters['360multi_axis_dict']['z004'])) self.axis_z005_entry.setText(str(self.parameters['360multi_axis_dict']['z005'])) self.axis_z006_entry.setText(str(self.parameters['360multi_axis_dict']['z006'])) self.axis_z007_entry.setText(str(self.parameters['360multi_axis_dict']['z007'])) self.axis_z008_entry.setText(str(self.parameters['360multi_axis_dict']['z008'])) self.axis_z009_entry.setText(str(self.parameters['360multi_axis_dict']['z009'])) self.axis_z010_entry.setText(str(self.parameters['360multi_axis_dict']['z010'])) self.axis_z011_entry.setText(str(self.parameters['360multi_axis_dict']['z011'])) return 0 def input_button_pressed(self): LOG.debug("Input button pressed") dir_explore = QFileDialog(self) self.parameters['360multi_input_dir'] = dir_explore.getExistingDirectory() self.input_dir_entry.setText(self.parameters['360multi_input_dir']) def set_input_entry(self): LOG.debug("Input directory: " + str(self.input_dir_entry.text())) self.parameters['360multi_input_dir'] = str(self.input_dir_entry.text()) def temp_button_pressed(self): LOG.debug("Temp button pressed") dir_explore = QFileDialog(self) self.parameters['360multi_temp_dir'] = dir_explore.getExistingDirectory() self.temp_dir_entry.setText(self.parameters['360multi_temp_dir']) def set_temp_entry(self): LOG.debug("Temp directory: " + str(self.temp_dir_entry.text())) self.parameters['360multi_temp_dir'] = str(self.temp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Output button pressed") dir_explore = QFileDialog(self) self.parameters['360multi_output_dir'] = dir_explore.getExistingDirectory() self.output_dir_entry.setText(self.parameters['360multi_output_dir']) def set_output_entry(self): LOG.debug("Output directory: " + str(self.output_dir_entry.text())) self.parameters['360multi_output_dir'] = str(self.output_dir_entry.text()) def set_crop_projections_checkbox(self): LOG.debug("Crop projections: " + str(self.crop_checkbox.isChecked())) self.parameters['360multi_crop_projections'] = bool(self.crop_checkbox.isChecked()) def set_axis_bottom(self): LOG.debug("Axis Bottom : " + str(self.axis_bottom_entry.text())) self.parameters['360multi_bottom_axis'] = int(self.axis_bottom_entry.text()) def set_axis_top(self): LOG.debug("Axis Top: " + str(self.axis_top_entry.text())) self.parameters['360multi_top_axis'] = int(self.axis_top_entry.text()) def set_axis_group(self): if self.axis_group.isChecked(): self.axis_bottom_label.setEnabled(False) self.axis_bottom_entry.setEnabled(False) self.axis_top_label.setEnabled(False) self.axis_top_entry.setEnabled(False) self.parameters['360multi_manual_axis'] = True LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis'])) else: self.axis_bottom_label.setEnabled(True) self.axis_bottom_entry.setEnabled(True) self.axis_top_label.setEnabled(True) self.axis_top_entry.setEnabled(True) self.parameters['360multi_manual_axis'] = False LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis'])) def set_z000(self): LOG.debug("z000 axis: " + str(self.axis_z000_entry.text())) self.parameters['360multi_axis_dict']['z000'] = int(self.axis_z000_entry.text()) def set_z001(self): LOG.debug("z001 axis: " + str(self.axis_z001_entry.text())) self.parameters['360multi_axis_dict']['z001'] = int(self.axis_z001_entry.text()) def set_z002(self): LOG.debug("z002 axis: " + str(self.axis_z002_entry.text())) self.parameters['360multi_axis_dict']['z002'] = int(self.axis_z002_entry.text()) def set_z003(self): LOG.debug("z003 axis: " + str(self.axis_z003_entry.text())) self.parameters['360multi_axis_dict']['z003'] = int(self.axis_z003_entry.text()) def set_z004(self): LOG.debug("z004 axis: " + str(self.axis_z004_entry.text())) self.parameters['360multi_axis_dict']['z004'] = int(self.axis_z004_entry.text()) def set_z005(self): LOG.debug("z005 axis: " + str(self.axis_z005_entry.text())) self.parameters['360multi_axis_dict']['z005'] = int(self.axis_z005_entry.text()) def set_z006(self): LOG.debug("z006 axis: " + str(self.axis_z006_entry.text())) self.parameters['360multi_axis_dict']['z006'] = int(self.axis_z006_entry.text()) def set_z007(self): LOG.debug("z007 axis: " + str(self.axis_z007_entry.text())) self.parameters['360multi_axis_dict']['z007'] = int(self.axis_z007_entry.text()) def set_z008(self): LOG.debug("z008 axis: " + str(self.axis_z008_entry.text())) self.parameters['360multi_axis_dict']['z008'] = int(self.axis_z008_entry.text()) def set_z009(self): LOG.debug("z009 axis: " + str(self.axis_z009_entry.text())) self.parameters['360multi_axis_dict']['z009'] = int(self.axis_z009_entry.text()) def set_z010(self): LOG.debug("z010 axis: " + str(self.axis_z010_entry.text())) self.parameters['360multi_axis_dict']['z010'] = int(self.axis_z010_entry.text()) def set_z011(self): LOG.debug("z011 axis: " + str(self.axis_z011_entry.text())) self.parameters['360multi_axis_dict']['z011'] = int(self.axis_z011_entry.text()) def stitch_button_pressed(self): LOG.debug("Stitch button pressed") self.get_fdt_names_on_stitch_pressed.emit() if os.path.exists(self.parameters['360multi_temp_dir']): qm = QMessageBox() rep = qm.question(self, '', "Temporary dir is not empty. Is it safe to delete it?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['360multi_temp_dir']) except: warning_message("Problems with deleting directory") else: return if os.path.exists(self.parameters['360multi_output_dir']): warning_message('Output directory exists. Delete it or select another one.') return # if os.path.exists(self.parameters['360multi_output_dir']): # # raise ValueError('Output directory exists') # qm = QMessageBox() # rep = qm.question(self, '', "Output dir is not empty. Can I delete it?", qm.Yes | qm.No) # if rep == qm.Yes: # os.system('rm -r {}'.format(self.parameters['360multi_output_dir'])) # else: # return print("======= Begin 360 Multi-Stitch =======") main_360_mp_depth2(self.parameters) if os.path.isdir(self.parameters['360multi_output_dir']): params_file_path = os.path.join(self.parameters['360multi_output_dir'], '360_multi_stitch_params.yaml') params.save_parameters(self.parameters, params_file_path) print("==== Waiting for Next Task ====") def delete_button_pressed(self): print("---- Deleting Data From Output Directory ----") LOG.debug("Delete button pressed") qm = QMessageBox() rep = qm.question(self, '', "Is it safe to delete the directory?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['360multi_output_dir']) except: warning_message("Problems with deleting directory") else: return def help_button_pressed(self): LOG.debug("Help button pressed") h = "Stitches images horizontally\n" h += "Directory structure is, f.i., Input/000, Input/001,...Input/00N\n" h += "Each 000, 001, ... 00N directory must have identical subdirectory \"Type\"\n" h += "Selected range of images from \"Type\" directory will be stitched vertically\n" h += "across all subdirectories in the Input directory" h += "to be added as options:\n" h += "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching" QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_overlap_qt.py000066400000000000000000000300301423713721100237030ustar00rootroot00000000000000from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import pyqtSignal import logging from shutil import rmtree import yaml import os from tofu.ez.Helpers.find_360_overlap import find_overlap import tofu.ez.params as params import getpass #TODO Make all stitching tools compatible with the bigtiffs LOG = logging.getLogger(__name__) class Overlap360Group(QGroupBox): get_fdt_names_on_stitch_pressed = pyqtSignal() def __init__(self): super().__init__() self.setTitle("Reconstruct one slice with different axis of rotation positions for half-acqusition mode data set(s)") self.setStyleSheet('QGroupBox {color: Orange;}') self.input_dir_button = QPushButton("Select input directory") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.temp_dir_button = QPushButton("Select temp directory") self.temp_dir_button.clicked.connect(self.temp_button_pressed) self.temp_dir_entry = QLineEdit() self.temp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton("Select output directory") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.pixel_row_label = QLabel("Pixel row to be used for sinogram") self.pixel_row_entry = QLineEdit() self.pixel_row_entry.editingFinished.connect(self.set_pixel_row) self.min_label = QLabel("Lower limit of stitch/axis search range") self.min_entry = QLineEdit() self.min_entry.editingFinished.connect(self.set_lower_limit) self.max_label = QLabel("Upper limit of stitch/axis search range") self.max_entry = QLineEdit() self.max_entry.editingFinished.connect(self.set_upper_limit) self.step_label = QLabel("Value by which to increment through search range") self.step_entry = QLineEdit() self.step_entry.editingFinished.connect(self.set_increment) self.axis_on_left = QCheckBox("Apply ring removal") self.axis_on_left.setEnabled(False) self.axis_on_left.stateChanged.connect(self.set_axis_checkbox) self.help_button = QPushButton("Help") self.help_button.clicked.connect(self.help_button_pressed) self.find_overlap_button = QPushButton("Generate slices") self.find_overlap_button.clicked.connect(self.overlap_button_pressed) self.find_overlap_button.setStyleSheet("color:royalblue;font-weight:bold") self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 2) layout.addWidget(self.input_dir_entry, 1, 0, 1, 2) layout.addWidget(self.temp_dir_button, 2, 0, 1, 2) layout.addWidget(self.temp_dir_entry, 3, 0, 1, 2) layout.addWidget(self.output_dir_button, 4, 0, 1, 2) layout.addWidget(self.output_dir_entry, 5, 0, 1, 2) layout.addWidget(self.pixel_row_label, 6, 0) layout.addWidget(self.pixel_row_entry, 6, 1) layout.addWidget(self.min_label, 7, 0) layout.addWidget(self.min_entry, 7, 1) layout.addWidget(self.max_label, 8, 0) layout.addWidget(self.max_entry, 8, 1) layout.addWidget(self.step_label, 9, 0) layout.addWidget(self.step_entry, 9, 1) layout.addWidget(self.axis_on_left, 10, 0) layout.addWidget(self.help_button, 11, 0) layout.addWidget(self.find_overlap_button, 11, 1) layout.addWidget(self.import_parameters_button, 12, 0) layout.addWidget(self.save_parameters_button, 12, 1) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': '360_overlap'} self.parameters['360overlap_input_dir'] = os.path.expanduser('~') self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) self.parameters['360overlap_temp_dir'] = os.path.join( os.path.expanduser('~'), "tmp-360axis-search") self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) self.parameters['360overlap_output_dir'] = os.path.join( os.path.expanduser('~'), "ezufo-360axis-search") self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) self.parameters['360overlap_row'] = 200 self.pixel_row_entry.setText(str(self.parameters['360overlap_row'])) self.parameters['360overlap_lower_limit'] = 100 self.min_entry.setText(str(self.parameters['360overlap_lower_limit'])) self.parameters['360overlap_upper_limit'] = 200 self.max_entry.setText(str(self.parameters['360overlap_upper_limit'])) self.parameters['360overlap_increment'] = 1 self.step_entry.setText(str(self.parameters['360overlap_increment'])) self.parameters['360overlap_axis_on_left'] = True self.axis_on_left.setChecked(bool(self.parameters['360overlap_axis_on_left'])) def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != '360_overlap': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) self.pixel_row_entry.setText(str(self.parameters['360overlap_row'])) self.min_entry.setText(str(self.parameters['360overlap_lower_limit'])) self.max_entry.setText(str(self.parameters['360overlap_upper_limit'])) self.step_entry.setText(str(self.parameters['360overlap_increment'])) self.axis_on_left.setChecked(bool(self.parameters['360overlap_axis_on_left'])) def input_button_pressed(self): LOG.debug("Select input button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_input_dir'] = dir_explore.getExistingDirectory() self.input_dir_entry.setText(self.parameters['360overlap_input_dir']) def set_input_entry(self): LOG.debug("Input: " + str(self.input_dir_entry.text())) self.parameters['360overlap_input_dir'] = str(self.input_dir_entry.text()) def temp_button_pressed(self): LOG.debug("Select temp button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_temp_dir'] = dir_explore.getExistingDirectory() self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir']) def set_temp_entry(self): LOG.debug("Temp: " + str(self.temp_dir_entry.text())) self.parameters['360overlap_temp_dir'] = str(self.temp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Select output button pressed") dir_explore = QFileDialog(self) self.parameters['360overlap_output_dir'] = dir_explore.getExistingDirectory() self.output_dir_entry.setText(self.parameters['360overlap_output_dir']) def set_output_entry(self): LOG.debug("Output: " + str(self.output_dir_entry.text())) self.parameters['360overlap_output_dir'] = str(self.output_dir_entry.text()) def set_pixel_row(self): LOG.debug("Pixel row: " + str(self.pixel_row_entry.text())) self.parameters['360overlap_row'] = int(self.pixel_row_entry.text()) def set_lower_limit(self): LOG.debug("Lower limit: " + str(self.min_entry.text())) self.parameters['360overlap_lower_limit'] = int(self.min_entry.text()) def set_upper_limit(self): LOG.debug("Upper limit: " + str(self.max_entry.text())) self.parameters['360overlap_upper_limit'] = int(self.max_entry.text()) def set_increment(self): LOG.debug("Value of increment: " + str(self.step_entry.text())) self.parameters['360overlap_increment'] = int(self.step_entry.text()) def set_axis_checkbox(self): LOG.debug("Is rotation axis on left-hand-side?: " + str(self.axis_on_left.isChecked())) self.parameters['360overlap_axis_on_left'] = bool(self.axis_on_left.isChecked()) def overlap_button_pressed(self): LOG.debug("Find overlap button pressed") if os.path.exists(self.parameters['360overlap_output_dir']) or \ os.path.exists(self.parameters['360overlap_temp_dir']): qm = QMessageBox() rep = qm.question(self, '', "Output directory or(and) temporary dir exist. Can I delete both?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['360overlap_output_dir']) rmtree(self.parameters['360overlap_temp_dir']) except: pass else: return os.makedirs(self.parameters['360overlap_temp_dir']) os.makedirs(self.parameters['360overlap_output_dir']) find_overlap(self.parameters) if os.path.exists(self.parameters['360overlap_output_dir']): params_file_path = os.path.join(self.parameters['360overlap_output_dir'], '360_overlap_params.yaml') params.save_parameters(self.parameters, params_file_path) def help_button_pressed(self): LOG.debug("Help button pressed") h = "This script takes as input a CT scan that has been collected in 'half-acquisition' mode" h += " and produces a series of reconstructed slices, each of which are generated by cropping and" h += " concatenating opposing projections together over a range of 'overlap' values (i.e. the pixel column" h += " at which the images are cropped and concatenated)." h += " The objective is to review this series of images to determine the pixel column at which the axis of rotation" h += " is located (much like the axis search function commonly used in reconstruction software)." QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ezmview_qt.py000066400000000000000000000243361423713721100226470ustar00rootroot00000000000000import os import logging from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QLineEdit, QLabel, QCheckBox, QGridLayout, QFileDialog, QMessageBox, ) import yaml from tofu.ez.Helpers.mview_main import main_prep LOG = logging.getLogger(__name__) class EZMViewGroup(QGroupBox): def __init__(self): super().__init__() self.args = {} self.e_indir = "" self.e_nproj = 0 self.e_nflats = 0 self.e_ndarks = 0 self.e_nviews = 0 self.e_noflats2 = False self.e_Andor = False self.setTitle("Split a sequence of tif files over flats/darks/tomo directories") self.setStyleSheet("QGroupBox {color: green;}") self.input_dir_button = QPushButton() self.input_dir_button.setText("Select directory with a CT sequence") self.input_dir_button.clicked.connect(self.select_directory) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_directory_entry) self.num_projections_label = QLabel() self.num_projections_label.setText("Number of projections") self.num_projections_entry = QLineEdit() self.num_projections_entry.editingFinished.connect(self.set_num_projections) self.num_flats_label = QLabel() self.num_flats_label.setText("Number of flats") self.num_flats_entry = QLineEdit() self.num_flats_entry.editingFinished.connect(self.set_num_flats) self.num_darks_label = QLabel() self.num_darks_label.setText("Number of darks") self.num_darks_entry = QLineEdit() self.num_darks_entry.editingFinished.connect(self.set_num_darks) self.num_vert_steps_label = QLabel() self.num_vert_steps_label.setText("Number of CT sets in the sequence") self.num_vert_steps_entry = QLineEdit() self.num_vert_steps_entry.editingFinished.connect(self.set_num_steps) self.no_trailing_flats_darks_checkbox = QCheckBox() self.no_trailing_flats_darks_checkbox.setText("No trailing flats/darks") self.no_trailing_flats_darks_checkbox.stateChanged.connect(self.set_trailing_checkbox) self.filenames_without_padding_checkbox = QCheckBox() self.filenames_without_padding_checkbox.setText("File names without zero padding") self.filenames_without_padding_checkbox.stateChanged.connect(self.set_file_names_checkbox) self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.clicked.connect(self.help_button_pressed) self.undo_button = QPushButton() self.undo_button.setText("Undo") self.undo_button.clicked.connect(self.undo_button_pressed) self.convert_button = QPushButton() self.convert_button.setText("Convert") self.convert_button.clicked.connect(self.convert_button_pressed) self.convert_button.setStyleSheet("color:royalblue;font-weight:bold") self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() layout.addWidget(self.input_dir_button, 0, 0, 1, 3) layout.addWidget(self.input_dir_entry, 1, 0, 1, 3) layout.addWidget(self.num_projections_label, 2, 0) layout.addWidget(self.num_projections_entry, 2, 1, 1, 2) layout.addWidget(self.num_flats_label, 3, 0) layout.addWidget(self.num_flats_entry, 3, 1, 1, 2) layout.addWidget(self.num_darks_label, 4, 0) layout.addWidget(self.num_darks_entry, 4, 1, 1, 2) layout.addWidget(self.num_vert_steps_label, 5, 0) layout.addWidget(self.num_vert_steps_entry, 5, 1, 1, 2) layout.addWidget(self.no_trailing_flats_darks_checkbox, 6, 0) layout.addWidget(self.filenames_without_padding_checkbox, 6, 1, 1, 2) layout.addWidget(self.help_button, 7, 0, 1, 1) layout.addWidget(self.undo_button, 7, 1, 1, 1) layout.addWidget(self.convert_button, 7, 2, 1, 1) layout.addWidget(self.import_parameters_button, 8, 0, 1, 2) layout.addWidget(self.save_parameters_button, 8, 2, 1, 1) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': 'ez_mview'} self.input_dir_entry.setText(os.getcwd()) self.parameters['ezmview_input_dir'] = os.getcwd() self.num_projections_entry.setText("3000") self.parameters['ezmview_num_projections'] = 3000 self.num_flats_entry.setText("10") self.parameters['ezmview_num_flats'] = 10 self.num_darks_entry.setText("10") self.parameters['ezmview_num_darks'] = 10 self.num_vert_steps_entry.setText("1") self.parameters['ezmview_num_vertical_steps'] = 1 self.no_trailing_flats_darks_checkbox.setChecked(False) self.parameters['ezmview_flats2'] = False self.filenames_without_padding_checkbox.setChecked(False) self.parameters['ezmview_zero_padding'] = False def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != 'ez_mview': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(str(self.parameters['ezmview_input_dir'])) self.num_projections_entry.setText(str(self.parameters['ezmview_num_projections'])) self.num_flats_entry.setText(str(self.parameters['ezmview_num_flats'])) self.num_darks_entry.setText(str(self.parameters['ezmview_num_darks'])) self.num_vert_steps_entry.setText(str(self.parameters['ezmview_num_vertical_steps'])) self.no_trailing_flats_darks_checkbox.setChecked(bool(self.parameters['ezmview_flats2'])) self.filenames_without_padding_checkbox.setChecked(bool(self.parameters['ezmview_zero_padding'])) def select_directory(self): LOG.debug("Select directory button pressed") dir_explore = QFileDialog(self) directory = dir_explore.getExistingDirectory() self.input_dir_entry.setText(directory) self.parameters['ezmview_input_dir'] = directory def set_directory_entry(self): LOG.debug("Directory entry: " + str(self.input_dir_entry.text())) self.parameters['ezmview_input_dir'] = str(self.input_dir_entry.text()) def set_num_projections(self): LOG.debug("Num projections: " + str(self.num_projections_entry.text())) self.parameters['ezmview_num_projections'] = int(self.num_projections_entry.text()) def set_num_flats(self): LOG.debug("Num flats: " + str(self.num_flats_entry.text())) self.parameters['ezmview_num_flats'] = int(self.num_flats_entry.text()) def set_num_darks(self): LOG.debug("Num darks: " + str(self.num_darks_entry.text())) self.parameters['ezmview_num_darks'] = int(self.num_darks_entry.text()) def set_num_steps(self): LOG.debug("Num steps: " + str(self.num_vert_steps_entry.text())) self.parameters['ezmview_num_vertical_steps'] = int(self.num_vert_steps_entry.text()) def set_trailing_checkbox(self): LOG.debug("No trailing: " + str(self.no_trailing_flats_darks_checkbox.isChecked())) self.parameters['ezmview_flats2'] = bool(self.no_trailing_flats_darks_checkbox.isChecked()) def set_file_names_checkbox(self): LOG.debug("File names without zero padding: " + str(self.filenames_without_padding_checkbox.isChecked())) self.parameters['ezmview_zero_padding'] = \ bool(self.filenames_without_padding_checkbox.isChecked()) def convert_button_pressed(self): LOG.debug("Convert button pressed") LOG.debug(self.parameters) main_prep(self.parameters) def undo_button_pressed(self): LOG.debug("Undo button pressed") cmd = "find {} -type f -name \"*.tif\" -exec mv -t {} {{}} +" cmd = cmd.format(str(self.parameters['ezmview_input_dir']), str(self.parameters['ezmview_input_dir'])) os.system(cmd) def help_button_pressed(self): LOG.debug("Help button pressed") h = "Distributes a sequence of CT frames in flats/darks/tomo/flats2 directories\n" h += "assuming that acqusition sequence is flats->darks->tomo->flats2\n" h += 'Use only for sequences with flat fields acquired at 0 and 180!\n' h += "Conversions happens in-place but can be undone" QMessageBox.information(self, "Help", h) def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ezstitch_qt.py000066400000000000000000000527221423713721100230160ustar00rootroot00000000000000import os from PyQt5.QtWidgets import ( QGroupBox, QPushButton, QCheckBox, QLabel, QLineEdit, QGridLayout, QVBoxLayout, QHBoxLayout, QRadioButton, QFileDialog, QMessageBox, ) from shutil import rmtree import logging import getpass import yaml import tofu.ez.params as params from tofu.ez.Helpers.stitch_funcs import main_sti_mp, main_conc_mp, main_360_mp_depth1 from tofu.ez.GUI.message_dialog import warning_message LOG = logging.getLogger(__name__) class EZStitchGroup(QGroupBox): def __init__(self): super().__init__() self.setTitle("Vertical stitching and reslicing tool") self.setStyleSheet('QGroupBox {color: purple;}') self.input_dir_button = QPushButton() self.input_dir_button.setText("Select input directory") self.input_dir_button.setToolTip("Normally contains a bunch of directories at the first depth level\n" \ "each of which has a subdirectory with the same name (second depth level). \n" "Images in these second-level subdirectories will be stitched together.") self.input_dir_button.clicked.connect(self.input_button_pressed) self.input_dir_entry = QLineEdit() self.input_dir_entry.editingFinished.connect(self.set_input_entry) self.tmp_dir_button = QPushButton() self.tmp_dir_button.setText("Select temporary directory") self.tmp_dir_button.clicked.connect(self.temp_button_pressed) self.tmp_dir_entry = QLineEdit() self.tmp_dir_entry.editingFinished.connect(self.set_temp_entry) self.output_dir_button = QPushButton() self.output_dir_button.setText("Directory to save stitched images") self.output_dir_button.clicked.connect(self.output_button_pressed) self.output_dir_entry = QLineEdit() self.output_dir_entry.editingFinished.connect(self.set_output_entry) self.types_of_images_label = QLabel() tmpstr = "Name of subdirectories which contain the same type of images in every directory in the input" self.types_of_images_label.setToolTip(tmpstr) self.types_of_images_label.setText("Name of subdirectories with the same type of images to stitch(e.g. sli, tomo, proj-pr, etc.)") self.types_of_images_entry = QLineEdit() self.types_of_images_entry.setToolTip(tmpstr) self.types_of_images_entry.editingFinished.connect(self.set_type_images) self.orthogonal_checkbox = QCheckBox() self.orthogonal_checkbox.setText("Stitch orthogonal sections (will reslice images in every subdirectory and then stitch)") self.orthogonal_checkbox.stateChanged.connect(self.set_stitch_checkbox) self.start_stop_step_label = QLabel() self.start_stop_step_label.setText("Which images to be stitched: start,stop,step:") self.start_stop_step_entry = QLineEdit() self.start_stop_step_entry.editingFinished.connect(self.set_start_stop_step) self.sample_moved_down_checkbox = QCheckBox() self.sample_moved_down_checkbox.setText("Sample was moved downwards during scan") self.sample_moved_down_checkbox.stateChanged.connect(self.set_sample_moved_down) self.interpolate_regions_rButton = QRadioButton() self.interpolate_regions_rButton.setText("Interpolate overlapping regions and equalize intensity") self.interpolate_regions_rButton.clicked.connect(self.set_rButton) self.num_overlaps_label = QLabel() self.num_overlaps_label.setText("Number of overlapping rows") self.num_overlaps_entry = QLineEdit() self.num_overlaps_entry.editingFinished.connect(self.set_overlap) self.clip_histogram_checkbox = QCheckBox() self.clip_histogram_checkbox.setText("Clip histogram and convert slices to 8-bit before saving") self.clip_histogram_checkbox.stateChanged.connect(self.set_histogram_checkbox) self.min_value_label = QLabel() self.min_value_label.setText("Min value in 32-bit histogram") self.min_value_entry = QLineEdit() self.min_value_entry.editingFinished.connect(self.set_min_value) self.max_value_label = QLabel() self.max_value_label.setText("Max value in 32-bit histogram") self.max_value_entry = QLineEdit() self.max_value_entry.editingFinished.connect(self.set_max_value) self.concatenate_rButton = QRadioButton() self.concatenate_rButton.setText("Concatenate only") self.concatenate_rButton.clicked.connect(self.set_rButton) self.first_row_label = QLabel() self.first_row_label.setText("First row") self.first_row_entry = QLineEdit() self.first_row_entry.editingFinished.connect(self.set_first_row) self.last_row_label = QLabel() self.last_row_label.setText("Last row") self.last_row_entry = QLineEdit() self.last_row_entry.editingFinished.connect(self.set_last_row) self.half_acquisition_rButton = QRadioButton() self.half_acquisition_rButton.setText("Horizontal stitching of half-acq. mode data (applies to tif images in the Input)") #self.half_acquisition_rButtonYfor a half-acqusition mode data (even number of tif files in the Input directory)") self.half_acquisition_rButton.clicked.connect(self.set_rButton) self.column_of_axis_label = QLabel() self.column_of_axis_label.setText("In which column the axis of rotation is") self.column_of_axis_entry = QLineEdit() self.column_of_axis_entry.editingFinished.connect(self.set_axis_column) self.help_button = QPushButton() self.help_button.setText("Help") self.help_button.clicked.connect(self.help_button_pressed) self.delete_button = QPushButton() self.delete_button.setText("Delete output dir") self.delete_button.clicked.connect(self.delete_button_pressed) self.stitch_button = QPushButton() self.stitch_button.setText("Stitch") self.stitch_button.clicked.connect(self.stitch_button_pressed) self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold") self.import_parameters_button = QPushButton("Import Parameters from File") self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed) self.save_parameters_button = QPushButton("Save Parameters to File") self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed) self.set_layout() def set_layout(self): layout = QGridLayout() vbox1 = QVBoxLayout() vbox1.addWidget(self.input_dir_button) vbox1.addWidget(self.input_dir_entry) vbox1.addWidget(self.tmp_dir_button) vbox1.addWidget(self.tmp_dir_entry) vbox1.addWidget(self.output_dir_button) vbox1.addWidget(self.output_dir_entry) layout.addItem(vbox1, 0, 0) grid = QGridLayout() grid.addWidget(self.types_of_images_label, 0, 0) grid.addWidget(self.types_of_images_entry, 0, 1) grid.addWidget(self.orthogonal_checkbox, 1, 0, 1, 2) grid.addWidget(self.start_stop_step_label, 2, 0) grid.addWidget(self.start_stop_step_entry, 2, 1) grid.addWidget(self.sample_moved_down_checkbox, 3, 0) grid.addWidget(self.interpolate_regions_rButton, 4, 0, 1, 2) grid.addWidget(self.num_overlaps_label, 5, 0) grid.addWidget(self.num_overlaps_entry, 5, 1) grid.addWidget(self.clip_histogram_checkbox, 6, 0) grid.addWidget(self.min_value_label, 7, 0) grid.addWidget(self.min_value_entry, 7, 1) grid.addWidget(self.max_value_label, 8, 0) grid.addWidget(self.max_value_entry, 8, 1) layout.addItem(grid, 1, 0) grid2 = QGridLayout() grid2.addWidget(self.concatenate_rButton, 0, 0, 1, 2) grid2.addWidget(self.first_row_label, 1, 0) grid2.addWidget(self.first_row_entry, 1, 1) grid2.addWidget(self.last_row_label, 1, 2) grid2.addWidget(self.last_row_entry, 1, 3) layout.addItem(grid2, 2, 0) grid3 = QGridLayout() grid3.addWidget(self.half_acquisition_rButton, 0, 0, 1, 2) grid3.addWidget(self.column_of_axis_label, 1, 0) grid3.addWidget(self.column_of_axis_entry, 1, 1) layout.addItem(grid3, 3, 0) grid4 = QGridLayout() grid4.addWidget(self.help_button, 0, 0) grid4.addWidget(self.delete_button, 0, 1) grid4.addWidget(self.stitch_button, 0, 2) grid4.addWidget(self.import_parameters_button, 1, 0, 1, 2) grid4.addWidget(self.save_parameters_button, 1, 2) layout.addItem(grid4, 4, 0) self.setLayout(layout) def init_values(self): self.parameters = {'parameters_type': 'ez_stitch'} self.parameters['ezstitch_input_dir'] = os.path.expanduser('~') self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) self.parameters['ezstitch_temp_dir'] = os.path.join( os.path.expanduser('~'), "tmp-ezstitch") self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) self.parameters['ezstitch_output_dir'] = os.path.join( os.path.expanduser('~'), "ezufo-stitched-images") self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) self.parameters['ezstitch_type_image'] = "sli" self.types_of_images_entry.setText(self.parameters['ezstitch_type_image']) self.parameters['ezstitch_stitch_orthogonal'] = True self.orthogonal_checkbox.setChecked(self.parameters['ezstitch_stitch_orthogonal']) self.parameters['ezstitch_start_stop_step'] = "200,2000,200" self.start_stop_step_entry.setText(self.parameters['ezstitch_start_stop_step']) self.parameters['ezstitch_sample_moved_down'] = False self.sample_moved_down_checkbox.setChecked(self.parameters['ezstitch_sample_moved_down']) self.parameters['ezstitch_stitch_type'] = 0 self.interpolate_regions_rButton.setChecked(True) self.concatenate_rButton.setChecked(False) self.half_acquisition_rButton.setChecked(False) self.parameters['ezstitch_num_overlap_rows'] = 60 self.num_overlaps_entry.setText(str(self.parameters['ezstitch_num_overlap_rows'])) self.parameters['ezstitch_clip_histo'] = False self.clip_histogram_checkbox.setChecked(self.parameters['ezstitch_clip_histo']) self.parameters['ezstitch_histo_min'] = -0.0003 self.min_value_entry.setText(str(self.parameters['ezstitch_histo_min'])) self.parameters['ezstitch_histo_max'] = 0.0002 self.max_value_entry.setText(str(self.parameters['ezstitch_histo_max'])) self.parameters['ezstitch_first_row'] = 40 self.first_row_entry.setText(str(self.parameters['ezstitch_first_row'])) self.parameters['ezstitch_last_row'] = 440 self.last_row_entry.setText(str(self.parameters['ezstitch_last_row'])) self.parameters['ezstitch_axis_of_rotation'] = 245 self.column_of_axis_entry.setText(str(self.parameters['ezstitch_axis_of_rotation'])) def update_parameters(self, new_parameters): LOG.debug("Update parameters") if new_parameters['parameters_type'] != 'ez_stitch': print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type'])) return -1 # Update parameters dictionary (which is passed to auto_stitch_funcs) self.parameters = new_parameters # Update displayed parameters for GUI self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) self.types_of_images_entry.setText(self.parameters['ezstitch_type_image']) self.orthogonal_checkbox.setChecked(self.parameters['ezstitch_stitch_orthogonal']) self.start_stop_step_entry.setText(self.parameters['ezstitch_start_stop_step']) self.sample_moved_down_checkbox.setChecked(self.parameters['ezstitch_sample_moved_down']) if self.parameters['ezstitch_stitch_type'] == 0: self.interpolate_regions_rButton.setChecked(True) elif self.parameters['ezstitch_stitch_type'] == 1: self.concatenate_rButton.setChecked(True) elif self.parameters['ezstitch_stitch_type'] == 2: self.half_acquisition_rButton.setChecked(True) self.num_overlaps_entry.setText(str(self.parameters['ezstitch_num_overlap_rows'])) self.clip_histogram_checkbox.setChecked(self.parameters['ezstitch_clip_histo']) self.min_value_entry.setText(str(self.parameters['ezstitch_histo_min'])) self.max_value_entry.setText(str(self.parameters['ezstitch_histo_max'])) self.first_row_entry.setText(str(self.parameters['ezstitch_first_row'])) self.last_row_entry.setText(str(self.parameters['ezstitch_last_row'])) self.column_of_axis_entry.setText(str(self.parameters['ezstitch_axis_of_rotation'])) def set_rButton(self): if self.interpolate_regions_rButton.isChecked(): LOG.debug("Interpolate regions") self.parameters['ezstitch_stitch_type'] = 0 elif self.concatenate_rButton.isChecked(): LOG.debug("Concatenate only") self.parameters['ezstitch_stitch_type'] = 1 elif self.half_acquisition_rButton.isChecked(): LOG.debug("Half-acquisition mode") self.parameters['ezstitch_stitch_type'] = 2 def input_button_pressed(self): LOG.debug("Input button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_input_dir'] = dir_explore.getExistingDirectory() self.input_dir_entry.setText(self.parameters['ezstitch_input_dir']) def set_input_entry(self): LOG.debug("Input: " + str(self.input_dir_entry.text())) self.parameters['ezstitch_input_dir'] = str(self.input_dir_entry.text()) def temp_button_pressed(self): LOG.debug("Temp button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_temp_dir'] = dir_explore.getExistingDirectory() self.tmp_dir_entry.setText(self.parameters['ezstitch_temp_dir']) def set_temp_entry(self): LOG.debug("Temp: " + str(self.tmp_dir_entry.text())) self.parameters['ezstitch_temp_dir'] = str(self.tmp_dir_entry.text()) def output_button_pressed(self): LOG.debug("Output button pressed") dir_explore = QFileDialog(self) self.parameters['ezstitch_output_dir'] = dir_explore.getExistingDirectory() self.output_dir_entry.setText(self.parameters['ezstitch_output_dir']) def set_output_entry(self): LOG.debug("Output: " + str(self.output_dir_entry.text())) self.parameters['ezstitch_output_dir'] = str(self.output_dir_entry.text()) def set_type_images(self): LOG.debug("Type of images: " + str(self.types_of_images_entry.text())) self.parameters['ezstitch_type_image'] = str(self.types_of_images_entry.text()) def set_stitch_checkbox(self): LOG.debug("Stitch orthogonal: " + str(self.orthogonal_checkbox.isChecked())) self.parameters['ezstitch_stitch_orthogonal'] = bool(self.orthogonal_checkbox.isChecked()) def set_start_stop_step(self): LOG.debug("Images to be stitched: " + str(self.start_stop_step_entry.text())) self.parameters['ezstitch_start_stop_step'] = str(self.start_stop_step_entry.text()) def set_sample_moved_down(self): LOG.debug("Sample moved down: " + str(self.sample_moved_down_checkbox.isChecked())) self.parameters['ezstitch_sample_moved_down'] = bool(self.sample_moved_down_checkbox.isChecked()) def set_overlap(self): LOG.debug("Num overlapping rows: " + str(self.num_overlaps_entry.text())) self.parameters['ezstitch_num_overlap_rows'] = int(self.num_overlaps_entry.text()) def set_histogram_checkbox(self): LOG.debug("Clip histogram: " + str(self.clip_histogram_checkbox.isChecked())) self.parameters['ezstitch_clip_histo'] = bool(self.clip_histogram_checkbox.isChecked()) def set_min_value(self): LOG.debug("Min value: " + str(self.min_value_entry.text())) self.parameters['ezstitch_histo_min'] = float(self.min_value_entry.text()) def set_max_value(self): LOG.debug("Max value: " + str(self.max_value_entry.text())) self.parameters['ezstitch_histo_max'] = float(self.max_value_entry.text()) def set_first_row(self): LOG.debug("First row: " + str(self.first_row_entry.text())) self.parameters['ezstitch_first_row'] = int(self.first_row_entry.text()) def set_last_row(self): LOG.debug("Last row: " + str(self.last_row_entry.text())) self.parameters['ezstitch_last_row'] = int(self.last_row_entry.text()) def set_axis_column(self): LOG.debug("Column of axis: " + str(self.column_of_axis_entry.text())) self.parameters['ezstitch_axis_of_rotation'] = int(self.column_of_axis_entry.text()) def help_button_pressed(self): LOG.debug("Help button pressed") h = "Stitches images vertically\n" h += "Directory structure is, f.i., Input/000, Input/001,...Input/00N\n" h += "Each 000, 001, ... 00N directory must have identical subdirectory \"Type\"\n" h += "Selected range of images from \"Type\" directory will be stitched vertically\n" h += "across all subdirectories in the Input directory" h += "to be added as options:\n" h += "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching" QMessageBox.information(self, "Help", h) def delete_button_pressed(self): LOG.debug("Delete button pressed") # if os.path.exists(self.parameters['ezstitch_output_dir']): # os.system('rm -r {}'.format(self.parameters['ezstitch_output_dir'])) # print(" - Directory with reconstructed data was removed") if os.path.exists(self.parameters['ezstitch_output_dir']): qm = QMessageBox() rep = qm.question(self, '', f"{self.parameters['ezstitch_output_dir']} \n" "will be removed. Continue?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['ezstitch_output_dir']) except: warning_message('Error while deleting directory') return else: return def stitch_button_pressed(self): LOG.debug("Stitch button pressed") if os.path.exists(self.parameters['ezstitch_temp_dir']): qm = QMessageBox() rep = qm.question(self, '', "Temporary dir is not empty. Is it safe to delete it?", qm.Yes | qm.No) if rep == qm.Yes: try: rmtree(self.parameters['ezstitch_temp_dir']) except: warning_message('Error while deleting directory') return else: return if os.path.exists(self.parameters['ezstitch_output_dir']): #raise ValueError('Output directory exists. Delete it or select another one.') warning_message('Output directory exists. Delete it or select another one.') return print("======= Begin Stitching =======") # Interpolate overlapping regions and equalize intensity if self.parameters['ezstitch_stitch_type'] == 0: main_sti_mp(self.parameters) # Concatenate only elif self.parameters['ezstitch_stitch_type'] == 1: main_conc_mp(self.parameters) # Half acquisition mode elif self.parameters['ezstitch_stitch_type'] == 2: main_360_mp_depth1(self.parameters['ezstitch_input_dir'], self.parameters['ezstitch_output_dir'], self.parameters['ezstitch_axis_of_rotation'], 0) if os.path.isdir(self.parameters['ezstitch_output_dir']): params_file_path = os.path.join(self.parameters['ezstitch_output_dir'], 'ezmview_params.yaml') params.save_parameters(self.parameters, params_file_path) print("==== Waiting for Next Task ====") def import_parameters_button_pressed(self): LOG.debug("Import params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getOpenFileName(filter="*.yaml") try: file_in = open(params_file_path[0], 'r') new_parameters = yaml.load(file_in, Loader=yaml.FullLoader) if self.update_parameters(new_parameters) == 0: print("Parameters file loaded from: " + str(params_file_path[0])) except FileNotFoundError: print("You need to select a valid input file") def save_parameters_button_pressed(self): LOG.debug("Save params button clicked") dir_explore = QFileDialog(self) params_file_path = dir_explore.getSaveFileName(filter="*.yaml") garbage, file_name = os.path.split(params_file_path[0]) file_extension = os.path.splitext(file_name) # If the user doesn't enter the .yaml extension then append it to filepath if file_extension[-1] == "": file_path = params_file_path[0] + ".yaml" else: file_path = params_file_path[0] try: file_out = open(file_path, 'w') yaml.dump(self.parameters, file_out) print("Parameters file saved at: " + str(file_path)) except FileNotFoundError: print("You need to select a directory and use a valid file name") tofu-0.12.0/tofu/ez/GUI/__init__.py000066400000000000000000000000011423713721100166670ustar00rootroot00000000000000 tofu-0.12.0/tofu/ez/GUI/default_settings.yaml000066400000000000000000000060341423713721100210220ustar00rootroot00000000000000# Default configuration file for ez_ufo_qt # Modify at your own peril --- main_config_input_dir: "" main_config_temp_dir: "" main_config_output_dir: "" main_config_darks_dir_name: "darks" main_config_flats_dir_name: "flats" main_config_tomo_dir_name: "tomo" main_config_flats2_dir_name: "flats2" main_config_save_multipage_tiff: false main_cor_axis_search_method: 1 main_cor_axis_search_interval: "1010,1030,0.5" main_cor_search_row_start: 100 main_cor_recon_patch_size: 256 main_cor_axis_column: 0.0 main_cor_axis_increment_step: 0.0 main_filters_remove_spots: false main_filters_remove_spots_threshold: 1000 main_filters_remove_spots_blur_sigma: 2 main_filters_ring_removal: false main_filters_ring_removal_ufo_lpf: true main_filters_ring_removal_ufo_lpf_1d_or_2d: true main_filters_ring_removal_ufo_lpf_sigma_horizontal: 3 main_filters_ring_removal_ufo_lpf_sigma_vertical: 1 main_filters_ring_removal_sarepy_window_size: 21 main_filters_ring_removal_sarepy_wide: false main_filters_ring_removal_sarepy_window: 91 main_filters_ring_removal_sarepy_SNR: 3 main_pr_phase_retrieval: false main_pr_photon_energy: 20 main_pr_pixel_size: 3.6 main_pr_detector_distance: 0.1 main_pr_delta_beta_ratio: 200 main_region_select_rows: false main_region_first_row: 100 main_region_number_rows: 200 main_region_nth_row: 20 main_region_clip_histogram: false main_region_bit_depth: 8 main_region_histogram_min: 0.0 main_region_histogram_max: 0.0 main_config_preprocess: false main_config_preprocess_command: "remove-outliers size=3 threshold=500 sign=1" main_region_rotate_volume_clock: 0.0 main_region_crop_slices: false main_region_crop_x: 0 main_region_crop_width: 0 main_region_crop_y: 0 main_region_crop_height: 0 main_config_dry_run: false main_config_save_params: true main_config_keep_temp: false advanced_ffc_sinFFC: false advanced_ffc_method: 1 advanced_ffc_eigen_pco_reps: 4 advanced_ffc_eigen_pco_downsample: 2 advanced_ffc_downsample: 4 main_config_open_viewer: True main_config_common_flats_darks: False main_config_darks_path: "Absolute path to darks" main_config_flats_path: "Absolute path to flats" main_config_flats2_checkbox: False main_config_flats2_path: "Absolute path to flats2" #NLMDN Settings advanced_nlmdn_apply_after_reco: False advanced_nlmdn_input_dir: "" advanced_nlmdn_input_is_file: False advanced_nlmdn_output_dir: "" advanced_nlmdn_save_bigtiff: False advanced_nlmdn_sim_search_radius: 10 advanced_nlmdn_patch_radius: 3 advanced_nlmdn_smoothing_control: 0.0 advanced_nlmdn_noise_std: 0.0 advanced_nlmdn_window: 0.0 advanced_nlmdn_fast: True advanced_nlmdn_estimate_sigma: False advanced_nlmdn_dry_run: False #ADVANCED TOFU Settings advanced_advtofu_lamino_angle: 30 advanced_adv_tofu_z_axis_rotation: 360 advanced_advtofu_center_position_z: "" advanced_advtofu_y_axis_rotation: "" advanced_advtofu_aux_ffc_dark_scale: "" advanced_advtofu_aux_ffc_flat_scale: "" #Optimization Settings advanced_advtofu_extended_settings: False advanced_optimize_verbose_console: False advanced_optimize_slice_mem_coeff: 0.5 advanced_optimize_num_gpus: "" advanced_optimize_slices_per_device: "" ...tofu-0.12.0/tofu/ez/GUI/ezufo_launcher.py000066400000000000000000000253551423713721100201640ustar00rootroot00000000000000import logging import os import sys from PyQt5 import QtWidgets as qtw from tofu.ez.GUI.Main.centre_of_rotation import CentreOfRotationGroup from tofu.ez.GUI.Main.filters import FiltersGroup from tofu.ez.GUI.Advanced.ffc import FFCGroup from tofu.ez.GUI.Main.phase_retrieval import PhaseRetrievalGroup from tofu.ez.GUI.Main.region_and_histogram import ROIandHistGroup from tofu.ez.GUI.Main.config import ConfigGroup from tofu.ez.main import clean_tmp_dirs from tofu.ez.yaml_in_out import Yaml_IO from tofu.ez.GUI.image_viewer import ImageViewerGroup import tofu.ez.params as parameters from tofu.ez.GUI.Advanced.advanced import AdvancedGroup from tofu.ez.GUI.Advanced.optimization import OptimizationGroup from tofu.ez.GUI.Advanced.nlmdn import NLMDNGroup from tofu.ez.GUI.Stitch_tools_tab.ez_360_multi_stitch_qt import MultiStitch360Group from tofu.ez.GUI.Stitch_tools_tab.ezstitch_qt import EZStitchGroup from tofu.ez.GUI.Stitch_tools_tab.ezmview_qt import EZMViewGroup from tofu.ez.GUI.Stitch_tools_tab.ez_360_overlap_qt import Overlap360Group from tofu.ez.GUI.login_dialog import Login LOG = logging.getLogger(__name__) class GUI(qtw.QWidget): """ Creates main GUI """ def __init__(self, *args, **kwargs): super(GUI, self).__init__(*args, **kwargs) self.setWindowTitle("EZ-UFO") self.setStyleSheet("font: 10pt; font-family: Arial") # Call login dialog # self.login_parameters = {} # QTimer.singleShot(0, self.login) # Read in default parameter settings from yaml file try: settings_path = os.path.dirname(os.path.abspath(__file__)) + "/default_settings.yaml" self.yaml_io = Yaml_IO() self.yaml_data = self.yaml_io.read_yaml(settings_path) parameters.params = dict(self.yaml_data) parameters.params["parameters_type"] = "ez_ufo_reco" except FileNotFoundError: print("Could not load default settings from: " + str(settings_path)) # Initialize tab screen self.tabs = qtw.QTabWidget() self.tab1 = qtw.QWidget() self.tab2 = qtw.QWidget() self.tab3 = qtw.QWidget() self.tab4 = qtw.QWidget() # Create and setup classes for each section of GUI # Main Tab self.centre_of_rotation_group = CentreOfRotationGroup() self.centre_of_rotation_group.init_values() self.filters_group = FiltersGroup() self.filters_group.init_values() self.ffc_group = FFCGroup() self.ffc_group.init_values() self.phase_retrieval_group = PhaseRetrievalGroup() self.phase_retrieval_group.init_values() self.binning_group = ROIandHistGroup() self.binning_group.init_values() self.config_group = ConfigGroup() self.config_group.init_values() # Image Viewer self.image_group = ImageViewerGroup() # Advanced Tab self.advanced_group = AdvancedGroup() self.advanced_group.init_values() self.optimization_group = OptimizationGroup() self.optimization_group.init_values() self.nlmdn_group = NLMDNGroup() self.nlmdn_group.init_values() # Stitch_tools_tab Tab self.multi_stitch_group = MultiStitch360Group() self.multi_stitch_group.init_values() self.ezmview_group = EZMViewGroup() self.ezmview_group.init_values() self.ezstitch_group = EZStitchGroup() self.ezstitch_group.init_values() self.overlap_group = Overlap360Group() self.overlap_group.init_values() ####################################################### self.set_layout() self.resize(0, 0) # window to minimum size # When new settings are imported signal is sent and this catches it to update params for each GUI object self.config_group.signal_update_vals_from_params.connect(self.update_values_from_params) # When RECO is done send signal from config self.config_group.signal_reco_done.connect(self.switch_to_image_tab) # To pass directory names from config tab to stitch tab when button pressed self.multi_stitch_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names) self.overlap_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names) finish = qtw.QAction("Quit", self) finish.triggered.connect(self.closeEvent) self.show() def set_layout(self): """ Set the layout of groups/tabs for the overall application layout """ layout = qtw.QVBoxLayout(self) main_layout = qtw.QGridLayout() main_layout.addWidget(self.centre_of_rotation_group, 0, 0) main_layout.addWidget(self.filters_group, 0, 1) main_layout.addWidget(self.phase_retrieval_group, 1, 0) main_layout.addWidget(self.binning_group, 1, 1) main_layout.addWidget(self.config_group, 2, 0, 2, 0) image_layout = qtw.QGridLayout() image_layout.addWidget(self.image_group, 0, 0) advanced_layout = qtw.QGridLayout() advanced_layout.addWidget(self.ffc_group, 0, 0) advanced_layout.addWidget(self.advanced_group, 1, 0) advanced_layout.addWidget(self.optimization_group, 1, 1) advanced_layout.addWidget(self.nlmdn_group, 2, 0) helpers_layout = qtw.QGridLayout() helpers_layout.addWidget(self.ezmview_group, 0, 0) helpers_layout.addWidget(self.overlap_group, 0, 1) helpers_layout.addWidget(self.multi_stitch_group, 1, 0) helpers_layout.addWidget(self.ezstitch_group, 1, 1) # Add tabs self.tabs.addTab(self.tab1, "Main") self.tabs.addTab(self.tab2, "Advanced") self.tabs.addTab(self.tab3, "Stitching tools") self.tabs.addTab(self.tab4, "Image Viewer") # Create main tab self.tab1.layout = main_layout self.tab1.setLayout(self.tab1.layout) # Create image tab self.tab4.layout = image_layout self.tab4.setLayout(self.tab4.layout) # Create advanced tab self.tab2.layout = advanced_layout self.tab2.setLayout(self.tab2.layout) # Create helpers tab self.tab3.layout = helpers_layout self.tab3.setLayout(self.tab3.layout) # Add tabs to widget layout.addWidget(self.tabs) self.setLayout(layout) def update_values_from_params(self): """ Updates displayed values when loaded in from external .yaml file of parameters """ LOG.debug("Update Values from Params") LOG.debug(parameters.params) self.centre_of_rotation_group.set_values_from_params() self.filters_group.set_values_from_params() self.ffc_group.set_values_from_params() self.phase_retrieval_group.set_values_from_params() self.binning_group.set_values_from_params() self.config_group.set_values_from_params() self.nlmdn_group.set_values_from_params() self.advanced_group.set_values_from_params() self.optimization_group.set_values_from_params() def switch_to_image_tab(self): """ Function is called after reconstruction when checkbox "Load images and open viewer after reconstruction" is enabled Automatically loads images from the output reconstruction directory for viewing """ if parameters.params["main_config_open_viewer"] is True: LOG.debug("Switch to Image Tab") self.tabs.setCurrentWidget(self.tab2) if os.path.isdir(str(parameters.params['main_config_output_dir'] + '/sli')): files = os.listdir(str(parameters.params['main_config_output_dir'] + '/sli')) #Start thread here to load images ##CHECK IF ONLY SINGLE IMAGE THEN USE OPEN IMAGE -- OTHERWISE OPEN STACK if len(files) == 1: print("Only one file in {}: Opening single image {}". format(parameters.params['main_config_output_dir'] + '/sli', files[0])) filePath = str(parameters.params['main_config_output_dir'] + '/sli/' + str(files[0])) self.image_group.open_image_from_filepath(filePath) else: print("Multiple files in {}: Opening stack of images". format(str(parameters.params['main_config_output_dir'] + '/sli'))) self.image_group.open_stack_from_path( str(parameters.params['main_config_output_dir'] + '/sli')) else: print("No output directory found") def closeEvent(self, event): """ Creates verification message box Cleans up temporary directories when user quits application """ logging.debug("QUIT") reply = qtw.QMessageBox.question(self, 'Quit', 'Are you sure you want to quit?', qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.No) if reply == qtw.QMessageBox.Yes: # remove all directories with projections clean_tmp_dirs(parameters.params['main_config_temp_dir'], self.config_group.get_fdt_names()) # remove axis-search dir too tmp = os.path.join(parameters.params['main_config_temp_dir'], 'axis-search') event.accept() else: event.ignore() def login(self): login_dialog = Login(self.login_parameters) if login_dialog.exec_() != qtw.QDialog.Accepted: self.exit() else: #self.file_writer_group.root_dir_entry.setText(self.login_parameters['expdir']) self.config_group.input_dir_entry.setText(self.login_parameters['expdir'] + "/raw") self.config_group.set_input_dir() self.config_group.output_dir_entry.setText(self.login_parameters['expdir'] + "/rec") self.config_group.set_output_dir() ''' td = date.today() tdstr = "{}.{}.{}".format(td.year, td.month, td.day) logfname = os.path.join(self.login_parameters['expdir'], 'exp-log-' + tdstr + '.log') if self.login_parameters.has_key('project'): logfname = os.path.join(self.login_parameters['expdir'], '{}-log-{}-{}.log'. format(self.login_parameters['project'], self.login_parameters['bl'], tdstr)) try: open(logfname, 'a').close() except: warning_message('Cannot create log file in the selected directory. \n' 'Check permissions and restart.') self.exit() ''' def exit(self): self.close() def main_qt(args=None): app = qtw.QApplication(sys.argv) window = GUI() sys.exit(app.exec_()) if __name__ == "__main__": main_qt() tofu-0.12.0/tofu/ez/GUI/image_viewer.py000066400000000000000000000357521423713721100176200ustar00rootroot00000000000000import os import logging import pyqtgraph as pg import numpy as np import tifffile from PyQt5.QtWidgets import ( QPushButton, QGroupBox, QLabel, QDoubleSpinBox, QRadioButton, QScrollBar, QVBoxLayout, QGridLayout, QFileDialog, QMessageBox, ) from PyQt5.QtCore import Qt import tofu.ez.image_read_write as image_read_write #TODO Integrate axis search tab ob tofu gui into this interface LOG = logging.getLogger(__name__) class ImageViewerGroup(QGroupBox): def __init__(self): super().__init__() #TODO: initialize on every opening with explicit data type #mmatching the data format being opened. #must check that there is enough RAM before loading!! self.tiff_arr = np.empty([0, 0, 0]) # float32 self.img_arr = np.empty([0, 0]) self.bit_depth = 32 self.open_file_button = QPushButton("Open Image File") self.open_file_button.clicked.connect(self.open_image_from_file) self.open_file_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.open_stack_button = QPushButton("Open Image Stack") self.open_stack_button.clicked.connect(self.open_stack_from_directory) self.open_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_file_button = QPushButton("Save Image File") self.save_file_button.clicked.connect(self.save_image_to_file) self.save_file_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_stack_button = QPushButton("Save Image Stack") self.save_stack_button.clicked.connect(self.save_stack_to_directory) self.save_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.open_big_tiff_button = QPushButton("Open BigTiff") self.open_big_tiff_button.clicked.connect(self.open_big_tiff) self.open_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_big_tiff_button = QPushButton("Save BigTiff") self.save_big_tiff_button.clicked.connect(self.save_stack_to_big_tiff) self.save_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt") self.save_8bit_rButton = QRadioButton() self.save_8bit_rButton.setText("Save as 8-bit") self.save_8bit_rButton.clicked.connect(self.set_8bit) self.save_8bit_rButton.setChecked(False) self.save_16bit_rButton = QRadioButton() self.save_16bit_rButton.setText("Save as 16-bit") self.save_16bit_rButton.clicked.connect(self.set_16bit) self.save_16bit_rButton.setChecked(False) self.save_32bit_rButton = QRadioButton() self.save_32bit_rButton.setText("Save as 32-bit") self.save_32bit_rButton.clicked.connect(self.set_32bit) self.save_32bit_rButton.setChecked(True) self.hist_min_label = QLabel("Histogram Min:") self.hist_min_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter) self.hist_max_label = QLabel("Histogram Max:") self.hist_max_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter) self.hist_min_input = QDoubleSpinBox() self.hist_min_input.setDecimals(12) self.hist_min_input.setRange(-10, 10) self.hist_min_input.valueChanged.connect(self.min_spin_changed) self.hist_max_input = QDoubleSpinBox() self.hist_max_input.setDecimals(12) self.hist_max_input.setRange(-10, 10) self.hist_max_input.valueChanged.connect(self.max_spin_changed) self.apply_histogram_button = QPushButton("Apply Histogram to Image Stack") self.apply_histogram_button.clicked.connect(self.apply_histogram_button_clicked) self.image_window = pg.ImageView() self.image_window.ui.histogram.gradient.hide() self.histo = self.image_window.getHistogramWidget() self.scroller = QScrollBar(Qt.Horizontal) self.scroller.orientation() self.scroller.setEnabled(False) self.scroller.valueChanged.connect(self.scroll_changed) self.set_layout() def set_layout(self): vbox = QVBoxLayout() vbox.addWidget(self.save_8bit_rButton) vbox.addWidget(self.save_16bit_rButton) vbox.addWidget(self.save_32bit_rButton) gridbox = QGridLayout() gridbox.addWidget(self.hist_max_label, 0, 0) gridbox.addWidget(self.hist_max_input, 0, 1) gridbox.addWidget(self.hist_min_label, 1, 0) gridbox.addWidget(self.hist_min_input, 1, 1) layout = QGridLayout() layout.addWidget(self.open_file_button, 0, 0) layout.addWidget(self.save_file_button, 1, 0) layout.addWidget(self.open_stack_button, 0, 1) layout.addWidget(self.save_stack_button, 1, 1) layout.addWidget(self.open_big_tiff_button, 0, 2) layout.addWidget(self.save_big_tiff_button, 1, 2) layout.addItem(vbox, 0, 3, 2, 1) layout.addItem(gridbox, 0, 4, 2, 1) layout.addWidget(self.apply_histogram_button, 0, 5) layout.addWidget(self.image_window, 2, 0, 1, 6) layout.addWidget(self.scroller, 4, 0, 1, 5) self.setLayout(layout) self.resize(640, 480) self.show() def scroll_changed(self): """ Updated the currently displayed image based on position of scroll bar :return: None """ self.image_window.setImage(self.tiff_arr[self.scroller.value()].T) def open_image_from_file(self): """ Opens and displays a single image (.tif) specified by the user in the file dialog :return: None """ LOG.debug("Open image button pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options ) if filePath: LOG.debug("Import image path: " + filePath) self.img_arr = image_read_write.read_image(filePath) self.image_window.setImage(self.img_arr.T) self.scroller.setEnabled(False) def open_image_from_filepath(self, filePath): """ Opens and displays a single image (.tif) contained in a directory - (used when one slice is reconstructed) :param filePath: Full path and filename :return: None """ LOG.debug("Open image from filepath: " + str(filePath)) if filePath: LOG.debug("Import image path: " + filePath) self.img_arr = image_read_write.read_image(filePath) self.image_window.setImage(self.img_arr.T) self.scroller.setEnabled(False) def save_image_to_file(self): """ Saves the currently displayed image to a file (.tif) specified by the user in the file dialog :return: None """ LOG.debug("Save image to file") options = QFileDialog.Options() filepath, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options ) if filepath: LOG.debug(filepath) bit_depth_string = self.check_bit_depth(self.bit_depth) img = self.image_window.imageItem.qimage # https://www.programmersought.com/article/73475006380/ size = img.size() s = img.bits().asstring( size.width() * size.height() * img.depth() // 8 ) # format 0xffRRGGBB arr = np.fromstring(s, dtype=np.uint8).reshape( (size.height(), size.width(), img.depth() // 8) ) image_read_write.write_image( arr.T[0].T, os.path.dirname(filepath), os.path.basename(filepath), bit_depth_string ) def open_stack_from_directory(self): """ Opens all images (.tif) in a directory and displays them. Allows for scrolling through images with slider :return: None """ LOG.debug("Open image stack button pressed") dir_explore = QFileDialog() directory = dir_explore.getExistingDirectory() if directory: try: tiff_list = (".tif", ".tiff") msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from Directory") msg.show() self.tiff_arr = image_read_write.read_all_images(directory, tiff_list) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) except image_read_write.InvalidDataSetError: print("Invalid Data Set") def open_stack_from_path(self, dir_path: str): """ Read images (.tif) from directory path into RAM as 3D numpy array :param dir_path: Path to directory containing multiple .tiff image files """ LOG.debug("Open stack from path") try: tiff_list = (".tif", ".tiff") msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from Directory") msg.show() self.tiff_arr = image_read_write.read_all_images(dir_path, tiff_list) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) except image_read_write.InvalidDataSetError: print("Invalid Data Set") def save_stack_to_directory(self): """ Saves images stored in numpy array to individual files (.tif) in directory specified by user dialog Saves these images as BigTiff if checkbox is set to True """ LOG.debug("Save stack to directory button pressed") LOG.debug("Saving with bitdepth: " + str(self.bit_depth)) dir_explore = QFileDialog() directory = dir_explore.getExistingDirectory() LOG.debug("Writing to directory: " + directory) if directory: bit_depth_string = self.check_bit_depth(self.bit_depth) msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Saving Images...") msg.setText("Saving Images to Directory") msg.show() self.apply_histogram_to_images() image_read_write.write_all_images(self.tiff_arr, directory, bit_depth_string) msg.close() def open_big_tiff(self): """ Opens images stored in a big tiff file (.tif) and displays them. Allows user to view them using scrollbar. :return: None """ LOG.debug("Open big tiff button pressed") options = QFileDialog.Options() filePath, _ = QFileDialog.getOpenFileName( self, "QFileDialog.getOpenFileName()", "", "All Files (*)", options=options ) if filePath: LOG.debug("Import image path: " + filePath) msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Loading Images...") msg.setText("Loading Images from BigTiff") msg.show() self.tiff_arr = tifffile.imread(filePath).astype(dtype=np.float32) self.scroller.setRange(0, self.tiff_arr.shape[0] - 1) self.scroller.setEnabled(True) self.image_window.setImage(self.tiff_arr[0].T) msg.close() mid_index = self.tiff_arr.shape[0] // 2 self.scroller.setValue(mid_index) def save_stack_to_big_tiff(self): """ Saves the stack of images currently loaded into RAM to a single bigtif file :return: None """ LOG.debug("Save stack to bigtiff button pressed") LOG.debug("Saving with bitdepth: " + str(self.bit_depth)) dir_explore = QFileDialog() options = QFileDialog.Options() filepath, _ = QFileDialog.getSaveFileName( self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options ) if filepath: msg = QMessageBox() msg.setIcon(QMessageBox.Information) msg.setWindowTitle("Saving Images...") msg.setText("Saving Images to BigTiff") msg.show() # self.apply_histogram_to_images() bit_depth_string = self.check_bit_depth(self.bit_depth) tifffile.imwrite(filepath, self.tiff_arr, bigtiff=True, dtype=bit_depth_string) msg.close() def min_spin_changed(self): """ Changes the levels of the histogram widget if the min spinbox has been changed :return: None """ histo = self.image_window.getHistogramWidget() levels = histo.getLevels() min_level = self.hist_min_input.value() self.image_window.setLevels(min_level, levels[1]) def max_spin_changed(self): """ Changes the levels of the histogram widget if the max spinbox has been changed :return: None """ histo = self.image_window.getHistogramWidget() levels = histo.getLevels() max_level = self.hist_max_input.value() self.image_window.setLevels(levels[0], max_level) def apply_histogram_button_clicked(self): LOG.debug("Apply Histogram Button Clicked") print("Applying histogram to images. This may take a moment.") self.apply_histogram_to_images() def apply_histogram_to_images(self): """ Gets the histogram levels of the currently displayed image and applies them to all images in RAM :return: None """ levels = self.histo.getLevels() self.tiff_arr = np.clip(self.tiff_arr, levels[0], levels[1]) def check_bit_depth(self, bit_depth: int) -> str: """ Returns a string indicating the bitdepth to store the images based on value of bit-depth radio buttons :param bit_depth: :return: String specifying datatype for numpy array """ if bit_depth == 8: return "uint8" elif bit_depth == 16: return "uint16" elif bit_depth == 32: return "uint32" def set_8bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 8-bit") self.bit_depth = 8 def set_16bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 16-bit") self.bit_depth = 16 def set_32bit(self): """ Sets value of bit_depth variable based on radio button selection :return: None """ LOG.debug("Set 32-bit") self.bit_depth = 32 tofu-0.12.0/tofu/ez/GUI/login_dialog.py000066400000000000000000000124521423713721100175740ustar00rootroot00000000000000import re from PyQt5.QtCore import Qt from PyQt5.QtWidgets import ( QDialog, QLineEdit, QPushButton, QLabel, QGridLayout, QFileDialog, QComboBox, ) from tofu.ez.GUI.message_dialog import error_message import os class Login(QDialog): def __init__(self, login_parameters_dict, **kwargs): super(Login, self).__init__(**kwargs) # Pass a method from main GUI self.login_parameters_dict = login_parameters_dict self.setWindowTitle("USER LOGIN") self.setWindowModality(Qt.ApplicationModal) self.setAttribute(Qt.WA_DeleteOnClose) self.welcome_label = QLabel() self.welcome_label.setText("Welcome to BMIT!") self.prompt_label_bl = QLabel() self.prompt_label_bl.setText("Please select the beamline and project:") self.bl_label = QLabel() self.bl_label.setText("Beamline:") self.bl_entry = QComboBox() self.bl_entry.addItems(["BM", "ID"]) self.proj_label = QLabel() self.proj_label.setText("Project:") self.proj_entry = QLineEdit() self.prompt_label_expdir = QLabel() self.prompt_label_expdir.setText("OR select the path to the working directory") self.expdir_entry = QLineEdit() # self.expdir_entry.setText("/data/gui-test") self.expdir_entry.setReadOnly(True) self.expdir_select_button = QPushButton("...") self.expdir_select_button.clicked.connect(self.select_expdir_func) self.login_button = QPushButton("LOGIN") self.login_button.clicked.connect(self.on_login_button_clicked) self.set_layout() def set_layout(self): layout = QGridLayout() self.welcome_label.setAlignment(Qt.AlignCenter) self.prompt_label_bl.setAlignment(Qt.AlignCenter) self.prompt_label_expdir.setAlignment(Qt.AlignCenter) layout.addWidget(self.welcome_label, 0, 0, 1, 2) layout.addWidget(self.prompt_label_bl, 1, 0, 1, 2) layout.addWidget(self.bl_label, 2, 0, 1, 1) layout.addWidget(self.bl_entry, 2, 1, 1, 1) layout.addWidget(self.proj_label, 3, 0, 1, 1) layout.addWidget(self.proj_entry, 3, 1, 1, 1) layout.addWidget(self.prompt_label_expdir, 4, 0, 1, 2) layout.addWidget(self.expdir_entry, 5, 0, 1, 1) layout.addWidget(self.expdir_select_button, 5, 1, 1, 1) layout.addWidget(self.login_button, 6, 0, 1, 2) layout.setSpacing(15) layout.setContentsMargins(25, 25, 25, 25) self.setLayout(layout) def select_expdir_func(self): options = QFileDialog.Options() options |= QFileDialog.DontUseNativeDialog root_dir = QFileDialog.getExistingDirectory( self, "Select working directory", "/data/gui-test", options=options ) if root_dir: self.expdir_entry.setText(root_dir) def uppercase_project_entry(self): self.proj_entry.setText(self.proj_entry.text().upper()) def strip_spaces_from_user_entry(self): self.user_entry.setText(self.user_entry.text().replace(" ", "")) @property def project_name(self): return self.proj_entry.text() @property def user_name(self): return self.user_entry.text() @property def expdir_name(self): return self.expdir_entry.text() @property def bl_name(self): return self.bl_entry.currentText() def validate_entries(self): self.uppercase_project_entry() # self.strip_spaces_from_user_entry() project_valid = bool(re.match(r"^[0-9]{2}[A-Z][0-9]{5}$", self.project_name)) # username_valid = bool(re.match(r"^[a-zA-Z0-9]*$", self.user_name)) # return project_valid, username_valid return project_valid def validate_dir(self, pdr): return os.access(pdr, os.W_OK) def on_login_button_clicked(self): # project_valid, username_valid = self.validate_entries() if self.project_name != "": prj_dir_name = os.path.join( "/beamlinedata/BMIT/projects/prj" + self.project_name, "raw" ) project_valid = self.validate_entries() can_write = self.validate_dir(prj_dir_name) if project_valid and can_write: self.login_parameters_dict.update({"bl": self.bl_name}) self.login_parameters_dict.update({"project": self.project_name}) # add fileExistsError exception later in Py3 self.login_parameters_dict.update({"expdir": prj_dir_name}) self.accept() # elif not username_valid: # error_message("Username should be alpha-numeric ") elif not project_valid: error_message( "The project should be in format: CCTNNNNN, \n" "where CC is cycle number, " "T is one-letter type, " "and NNNNN is project number" ) elif not can_write: error_message("Cannot write in the project directory") elif self.expdir_name != "": if self.validate_dir(self.expdir_entry.text()): self.login_parameters_dict.update({"expdir": self.expdir_name}) self.accept() else: error_message("Cannot write in the selected directory") tofu-0.12.0/tofu/ez/GUI/message_dialog.py000066400000000000000000000006661423713721100201140ustar00rootroot00000000000000from PyQt5.QtWidgets import QMessageBox def message_dialog(window_title, message_text): alert = QMessageBox() alert.setWindowTitle(window_title) alert.setText(message_text) alert.exec_() def error_message(message_text): message_dialog("Error", message_text) def warning_message(message_text): message_dialog("Warning", message_text) def info_message(message_text): message_dialog("Info", message_text) tofu-0.12.0/tofu/ez/Helpers/000077500000000000000000000000001423713721100155455ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/Helpers/__init__.py000066400000000000000000000000001423713721100176440ustar00rootroot00000000000000tofu-0.12.0/tofu/ez/Helpers/find_360_overlap.py000066400000000000000000000124201423713721100211560ustar00rootroot00000000000000""" This script takes as input a CT scan that has been collected in "half-acquisition" mode and produces a series of reconstructed slices, each of which are generated by cropping and concatenating opposing projections together over a range of "overlap" values (i.e. the pixel column at which the images are cropped and concatenated). The objective is to review this series of images to determine the pixel column at which the axis of rotation is located (much like the axis search function commonly used in reconstruction software). """ import os import numpy as np import tifffile from tofu.ez.image_read_write import TiffSequenceReader import tofu.ez.params as glob_parameters from tofu.ez.Helpers.stitch_funcs import findCTdirs, stitch_float32_output def extract_row(dir_name, row): tsr = TiffSequenceReader(dir_name) tmp = tsr.read(0) (N, M) = tmp.shape if (row < 0) or (row > N): row = N//2 num_images = tsr.num_images if num_images % 2 == 1: print(f"odd number of images ({num_images}) in {dir_name}, " f"discarding the last one before stitching pairs") num_images-=1 A = np.empty((num_images, M), dtype=np.uint16) for i in range(num_images): A[i, :] = tsr.read(i)[row, :] tsr.close() return A def find_overlap(parameters): print("Finding CTDirs...") ctdirs, lvl0 = findCTdirs(parameters['360overlap_input_dir'], glob_parameters.params['main_config_tomo_dir_name']) ctdirs.sort() print(ctdirs) # concatenate images end-to-end and generate a sinogram for ctset in ctdirs: print("Working on ctset:" + str(ctset)) index_dir = os.path.basename(os.path.normpath(ctset)) # loading: print(os.path.join(ctset, glob_parameters.params['main_config_flats_dir_name'])) try: row_flat = np.mean(extract_row( os.path.join(ctset, glob_parameters.params['main_config_flats_dir_name']), parameters['360overlap_row'])) except: print(f"Problem loading flats in {ctset}") continue try: row_dark = np.mean(extract_row( os.path.join(ctset, glob_parameters.params['main_config_darks_dir_name']), parameters['360overlap_row'])) except: print(f"Problem loading darks in {ctset}") continue try: row_tomo = extract_row( os.path.join(ctset, glob_parameters.params['main_config_tomo_dir_name']), parameters['360overlap_row']) except: print(f"Problem loading projections from " f"{os.path.join(ctset, glob_parameters.params['main_config_tomo_dir_name'])}") continue row_flat2 = None tmpstr = os.path.join(ctset, glob_parameters.params['main_config_flats2_dir_name']) if os.path.exists(tmpstr): try: row_flat2 = np.mean(extract_row(tmpstr, parameters['360overlap_row'])) except: print(f"Problem loading flats2 in {ctset}") (num_proj, M) = row_tomo.shape print('Flat-field correction...') # Flat-correction tmp_flat = np.tile(row_flat, (num_proj, 1)) if row_flat2 is not None: tmp_flat2 = np.tile(row_flat2, (num_proj, 1)) ramp = np.linspace(0, 1, num_proj) ramp = np.transpose(np.tile(ramp, (M, 1))) tmp_flat = tmp_flat * (1-ramp) + tmp_flat2 * ramp del ramp, tmp_flat2 tmp_dark = np.tile(row_dark, (num_proj, 1)) tomo_ffc = -np.log((row_tomo - tmp_dark)/np.float32(tmp_flat - tmp_dark)) del row_tomo, row_dark, row_flat, tmp_flat, tmp_dark np.nan_to_num(tomo_ffc, copy=False, nan=0.0, posinf=0.0, neginf=0.0) # create interpolated sinogram of flats on the # same row as we use for the projections, then flat/dark correction print('Creating stitched sinograms...') sin_tmp_dir = os.path.join(parameters['360overlap_temp_dir'], index_dir) os.makedirs(sin_tmp_dir) for axis in range(parameters['360overlap_lower_limit'], parameters['360overlap_upper_limit'], parameters['360overlap_increment']): cro = parameters['360overlap_upper_limit'] - axis A = stitch_float32_output( tomo_ffc[: num_proj//2, :], tomo_ffc[num_proj//2:, ::-1], axis, cro) tifffile.imsave(os.path.join( sin_tmp_dir, 'sin-axis-' + str(axis).zfill(4) + '.tif'), A.astype(np.float32)) # perform reconstructions for each sinogram and save to output folder print('Reconstructing slices...') reco_axis = M-parameters['360overlap_upper_limit'] cmd = f'tofu tomo --axis {reco_axis} --sinograms {sin_tmp_dir}' cmd +=' --output '+os.path.join(os.path.join( parameters['360overlap_output_dir'], f"{index_dir}-sli.tif")) print(cmd) os.system(cmd) print("Finished processing: " + str(index_dir)) print("********************DONE********************") #shutil.rmtree(parameters['360overlap_temp_dir']) print("Finished processing: " + str(parameters['360overlap_input_dir'])) tofu-0.12.0/tofu/ez/Helpers/mview_main.py000066400000000000000000000065421423713721100202610ustar00rootroot00000000000000#!/bin/python import os import numpy from tofu.util import get_filenames import re def check_folders(p, noflats2): if not os.path.exists(p): os.makedirs(p) tmp = p + "/darks" if not os.path.exists(tmp): os.makedirs(tmp) tmp = p + "/flats" if not os.path.exists(tmp): os.makedirs(tmp) if noflats2 == False: tmp = p + "/flats2" if not os.path.exists(tmp): os.makedirs(tmp) tmp = p + "/tomo" if not os.path.exists(tmp): os.makedirs(tmp) def rename_Andor(args): names = get_filenames(os.path.join(args.input, "*.tif")) maxnum = re.match(".*?([0-9]+)$", names[0][:-4]).group(1) n_dgts = len(maxnum) trnc_len = n_dgts + 4 prefix = names[0][:-trnc_len] maxnum = int(maxnum) for name in names: num = int(re.match(".*?([0-9]+)$", name[:-4]).group(1)) maxnum = num if (num > maxnum) else maxnum n_dgts = len(str(maxnum)) lin_fmt = prefix + "{:0" + str(n_dgts) + "}.tif" for name in names: num = re.match(".*?([0-9]+)$", name[:-4]).group(1) if name == lin_fmt.format(int(num)): continue else: cmd = "mv {} {}".format(name, lin_fmt.format(int(num))) os.system(cmd) def main_prep(args): if args.Andor: rename_Andor(args) frames = get_filenames(os.path.join(args.input, "*.tif")) nframes = len(frames) if nframes == 0: tmp = "Check INPUT directory: there are no tif files there" raise ValueError(tmp) # replace first frame with the second to get rid of # corrupted first file in the PCO Edge sequencies cmd = "rm {}; cp {} {}".format(frames[0], frames[1], frames[0]) os.system(cmd) FFinterval = args.nproj int_tot = args.nviews # (args.nproj/FFinterval)*args.nviews int_1view = 1.0 # args.nproj/FFinterval files_in_int = args.nflats + args.ndarks + FFinterval files_input = (args.nflats + args.ndarks + FFinterval) * int_tot if args.noflats2 == False: files_input += args.nflats + args.ndarks if files_input != nframes: tmp = ( "Sequence length (found {} files) does not match ".format(nframes) + "one calculated from input parameters " + "(expected {} files)".format(files_input) ) raise ValueError(tmp) for i in range(args.nviews): if args.nviews > 1: pout = os.path.join(args.output, "z{:02d}".format(i)) else: pout = args.output check_folders(pout, args.noflats2) # offset to heading flats and darks o = i * files_in_int for i in range(args.nflats): cmd = "mv {} {}/flats/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += args.nflats for i in range(args.ndarks): cmd = "mv {} {}/darks/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += args.ndarks for i in range(args.nproj): cmd = "mv {} {}/tomo/".format(frames[o + i], pout) os.system(cmd) # print(cmd) o += args.nproj if args.noflats2: continue for i in range(args.nflats): cmd = "cp {} {}/flats2/".format(frames[o + i], pout) os.system(cmd) # print(cmd) print("========== Done ==========") tofu-0.12.0/tofu/ez/Helpers/stitch_funcs.py000066400000000000000000000474221423713721100206240ustar00rootroot00000000000000""" Last modified on Apr 1, 2022 @author: sergei gasilov """ import glob import os import shutil import numpy as np import tifffile from tofu.util import read_image import multiprocessing as mp from functools import partial import re import warnings import time import tofu.ez.params as glob_parameters def findCTdirs(root: str, tomo_name: str): """ Walks directories rooted at "Input ctset" location Appends their absolute path to ctdir if they contain a ctset with same name as "tomo" entry in GUI """ lvl0 = os.path.abspath(root) ctdirs = [] for root, dirs, files in os.walk(lvl0): for name in dirs: if name == tomo_name: ctdirs.append(root) return ctdirs, lvl0 def prepare(parameters, dir_type: int, ctdir: str): """ :param parameters: GUI params :param dir_type 1 if CTDir containing Z00-Z0N slices - 2 if parent directory containing CTdirs each containing Z slices: :param ctdir Name of the ctdir - blank string if not using multiple ctdirs: :return: """ hmin, hmax = 0.0, 0.0 if parameters['ezstitch_clip_histo']: if parameters['ezstitch_histo_min'] == parameters['ezstitch_histo_max']: raise ValueError(' - Define hmin and hmax correctly in order to convert to 8bit') else: hmin, hmax = parameters['ezstitch_histo_min'], parameters['ezstitch_histo_max'] start, stop, step = [int(value) for value in parameters['ezstitch_start_stop_step'].split(',')] if not os.path.exists(parameters['ezstitch_output_dir']): os.makedirs(parameters['ezstitch_output_dir']) Vsteps = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], ctdir))) #determine input data type if dir_type == 1: tmp = os.path.join(parameters['ezstitch_input_dir'], Vsteps[0], parameters['ezstitch_type_image'], '*.tif') tmp = sorted(glob.glob(tmp))[0] indtype = type(read_image(tmp)[0][0]) elif dir_type == 2: tmp = os.path.join(parameters['ezstitch_input_dir'], ctdir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif') tmp = sorted(glob.glob(tmp))[0] indtype = type(read_image(tmp)[0][0]) if parameters['ezstitch_stitch_orthogonal']: for vstep in Vsteps: if dir_type == 1: in_name = os.path.join(parameters['ezstitch_input_dir'], vstep, parameters['ezstitch_type_image']) out_name = os.path.join(parameters['ezstitch_temp_dir'], vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif') elif dir_type == 2: in_name = os.path.join(parameters['ezstitch_input_dir'], ctdir, vstep, parameters['ezstitch_type_image']) out_name = os.path.join(parameters['ezstitch_temp_dir'], ctdir, vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif') cmd = 'tofu sinos --projections {} --output {}'.format(in_name, out_name) cmd += " --y {} --height {} --y-step {}".format(start, stop-start, step) cmd += " --output-bytes-per-file 0" os.system(cmd) time.sleep(10) indir = parameters['ezstitch_temp_dir'] else: indir = parameters['ezstitch_input_dir'] return indir, hmin, hmax, start, stop, step, indtype def exec_sti_mp(start, step, N, Nnew, Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type, j): index = start+j*step Large = np.empty((Nnew*len(Vsteps)+dx, M), dtype=np.float32) for i, vstep in enumerate(Vsteps[:-1]): if dir_type == 1: tmp = os.path.join(indir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif') tmp1 = os.path.join(indir, Vsteps[i+1], parameters['ezstitch_type_image'], '*.tif') elif dir_type == 2: tmp = os.path.join(indir, ctdir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif') tmp1 = os.path.join(indir, ctdir, Vsteps[i + 1], parameters['ezstitch_type_image'], '*.tif') if parameters['ezstitch_stitch_orthogonal']: tmp = sorted(glob.glob(tmp))[j] tmp1 = sorted(glob.glob(tmp1))[j] else: tmp = sorted(glob.glob(tmp))[index] tmp1 = sorted(glob.glob(tmp1))[index] first = read_image(tmp) second = read_image(tmp1) # sample moved downwards if parameters['ezstitch_sample_moved_down']: first, second = np.flipud(first), np.flipud(second) k = np.mean(first[N - dx:, :]) / np.mean(second[:dx, :]) second = second * k a, b, c = i*Nnew, (i+1)*Nnew, (i+2)*Nnew Large[a:b, :] = first[:N-dx, :] Large[b:b+dx, :] = np.transpose(np.transpose(first[N-dx:, :])*(1 - ramp) + np.transpose(second[:dx, :]) * ramp) Large[b+dx:c+dx, :] = second[dx:, :] pout = os.path.join(parameters['ezstitch_output_dir'], ctdir, parameters['ezstitch_type_image']+'-sti-{:>04}.tif'.format(index)) if not parameters['ezstitch_clip_histo']: tifffile.imsave(pout, Large.astype(indtype)) else: Large = 255.0/(hmax-hmin) * (np.clip(Large, hmin, hmax) - hmin) tifffile.imsave(pout, Large.astype(np.uint8)) def main_sti_mp(parameters): #Check whether indir is CTdir or parent containing CTdirs #if indir + some z00 subdir + sli + *.tif does not exist then use original subdirs = sorted(os.listdir(parameters['ezstitch_input_dir'])) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], parameters['ezstitch_type_image'])): dir_type = 1 ctdir = "" print(" - Using CT directory containing slices") if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, "") dx = int(parameters['ezstitch_num_overlap_rows']) # second: stitch them Vsteps = sorted(os.listdir(indir)) tmp = glob.glob(os.path.join(indir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif'))[0] first = read_image(tmp) N, M = first.shape Nnew = N - dx ramp = np.linspace(0, 1, dx) J = range((stop - start) // step) pool = mp.Pool(processes=mp.cpu_count()) # ??? IT was OK back in 2.7 but now can crash # if pool size is larger than array being multiprocessed? exec_func = partial(exec_sti_mp, start, step, N, Nnew, \ Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type) print(" - Adjusting and stitching") # start = time.time() pool.map(exec_func, J) print("========== Done ==========") else: second_subdirs = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], subdirs[0]))) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], second_subdirs[0], parameters['ezstitch_type_image'])): print(" - Using parent directory containing CT directories, each of which contains slices") dir_type = 2 #For each subdirectory do the same thing for ctdir in subdirs: print("-> Working on " + str(ctdir)) if not os.path.exists(os.path.join(parameters['ezstitch_output_dir'], ctdir)): os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir)) if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) dx = int(parameters['ezstitch_num_overlap_rows']) # second: stitch them Vsteps = sorted(os.listdir(os.path.join(indir, ctdir))) tmp = glob.glob(os.path.join(indir, ctdir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif'))[0] first = read_image(tmp) N, M = first.shape Nnew = N - dx ramp = np.linspace(0, 1, dx) J = range(int((stop - start) / step)) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_sti_mp, start, step, N, Nnew, \ Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type) print(" - Adjusting and stitching") # start = time.time() pool.map(exec_func, J) print("========== Done ==========") # Clear temp directory clear_tmp(parameters) else: print("Invalid input directory") complete_message() def make_buf(tmp, l, a, b): first = read_image(tmp) N, M = first[a:b, :].shape return np.empty((N*l, M), dtype=first.dtype), N, first.dtype def exec_conc_mp(start, step, example_im, l, parameters, zfold, indir, ctdir, j): index = start+j*step Large, N, dtype = make_buf(example_im, l, parameters['ezstitch_first_row'], parameters['ezstitch_last_row']) for i, vert in enumerate(zfold): tmp = os.path.join(indir, ctdir, vert, parameters['ezstitch_type_image'], '*.tif') if parameters['ezstitch_stitch_orthogonal']: fname=sorted(glob.glob(tmp))[j] else: fname=sorted(glob.glob(tmp))[index] frame = read_image(fname)[parameters['ezstitch_first_row']:parameters['ezstitch_last_row'], :] if parameters['ezstitch_sample_moved_down']: Large[i*N:N*(i+1), :] = np.flipud(frame) else: Large[i*N:N*(i+1), :] = frame pout = os.path.join(parameters['ezstitch_output_dir'], ctdir, parameters['ezstitch_type_image']+'-sti-{:>04}.tif'.format(index)) #print "input data type {:}".format(dtype) tifffile.imsave(pout, Large) def main_conc_mp(parameters): # Check whether indir is CTdir or parent containing CTdirs # if indir + some z00 subdir + sli + *.tif does not exist then use original subdirs = sorted(os.listdir(parameters['ezstitch_input_dir'])) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], parameters['ezstitch_type_image'])): dir_type = 1 ctdir = "" print(" - Using CT directory containing slices") if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") #start = time.time() indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) subdirs = [dI for dI in os.listdir(parameters['ezstitch_input_dir']) if os.path.isdir(os.path.join(parameters['ezstitch_input_dir'], dI))] zfold = sorted(subdirs) l = len(zfold) tmp = glob.glob(os.path.join(indir, zfold[0], parameters['ezstitch_type_image'], '*.tif')) J = range((stop-start)//step) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_conc_mp, start, step, tmp[0], l, parameters, zfold, indir, ctdir) print(" - Concatenating") #start = time.time() pool.map(exec_func, J) #print "Images stitched in {:.01f} sec".format(time.time()-start) print("============ Done ============") else: second_subdirs = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], subdirs[0]))) if os.path.exists(os.path.join(parameters['ezstitch_input_dir'], subdirs[0], second_subdirs[0], parameters['ezstitch_type_image'])): print(" - Using parent directory containing CT directories, each of which contains slices") dir_type = 2 for ctdir in subdirs: print(" == Working on " + str(ctdir) + " ==") if not os.path.exists(os.path.join(parameters['ezstitch_output_dir'], ctdir)): os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir)) if parameters['ezstitch_stitch_orthogonal']: print(" - Creating orthogonal sections") # start = time.time() indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir) zfold = sorted(os.listdir(os.path.join(indir, ctdir))) l = len(zfold) tmp = glob.glob(os.path.join(indir, ctdir, zfold[0], parameters['ezstitch_type_image'], '*.tif')) J = range((stop - start) // step) pool = mp.Pool(processes=mp.cpu_count()) exec_func = partial(exec_conc_mp, start, step, tmp[0], l, parameters, zfold, indir, ctdir) print(" - Concatenating") # start = time.time() pool.map(exec_func, J) # print "Images stitched in {:.01f} sec".format(time.time()-start) print("============ Done ============") #Clear temp directory clear_tmp(parameters) complete_message() ############################## HALF ACQ ############################## def stitch(first, second, axis, crop): h, w = first.shape if axis > w / 2: dx = int(2 * (w - axis) + 0.5) else: dx = int(2 * axis + 0.5) tmp = np.copy(first) first = second second = tmp result = np.empty((h, 2 * w - dx), dtype=first.dtype) ramp = np.linspace(0, 1, dx) # Mean values of the overlapping regions must match, which corrects flat-field inconsistency # between the two projections # We clip the values in second so that there are no saturated pixel overflow problems k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) second = np.clip(second * k, np.iinfo(np.uint16).min, np.iinfo(np.uint16).max).astype(np.uint16) result[:, :w - dx] = first[:, :w - dx] result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp result[:, w:] = second[:, dx:] return result[:, slice(int(crop), int(2*(w - axis) - crop), 1)] ############################## HALF ACQ ############################## def stitch_float32_output(first, second, axis, crop): print(f"Stitching two halves with axis {axis}") h, w = first.shape if axis > w / 2: dx = int(2 * (w - axis) + 0.5) else: dx = int(2 * axis + 0.5) tmp = np.copy(first) first = second second = tmp result = np.empty((h, 2 * w - dx), dtype=first.dtype) ramp = np.linspace(0, 1, dx) # Mean values of the overlapping regions must match, which corrects flat-field inconsistency # between the two projections # We clip the values in second so that there are no saturated pixel overflow problems k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) result[:, :w - dx] = first[:, :w - dx] result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp result[:, w:] = second[:, dx:] * k return result[:, slice(int(crop), int(2*(w - axis) - crop), 1)] def st_mp_idx(offst, ax, crop, in_fmt, out_fmt, idx): #we pass index and formats as argument first = read_image(in_fmt.format(idx)) second = read_image(in_fmt.format(idx+offst))[:, ::-1] stitched = stitch(first, second, ax, crop) tifffile.imwrite(out_fmt.format(idx), stitched) def main_360_mp_depth1(indir, outdir, ax, cro): if not os.path.exists(outdir): os.makedirs(outdir) subdirs = [dI for dI in os.listdir(indir) \ if os.path.isdir(os.path.join(indir, dI))] for i, sdir in enumerate(subdirs): print(f"Stitching images in {sdir}") names = sorted(glob.glob(os.path.join(indir, sdir, '*.tif'))) num_projs = len(names) if num_projs<2: warnings.warn("Warning: less than 2 files") print(str(num_projs) + ' files in ' + str(sdir)) os.makedirs(os.path.join(outdir, sdir)) out_fmt = os.path.join(outdir, sdir, 'sti-{:>04}.tif') # extraxt input file format firstfname = names[0] firstnum = re.match('.*?([0-9]+)$', firstfname[:-4]).group(1) n_dgts = len(firstnum) #number of significant digits idx0 = int(firstnum) trnc_len = n_dgts + 4 #format + .tif in_fmt = firstfname[:-trnc_len] + '{:0'+str(n_dgts)+'}.tif' offst = int(num_projs / 2) exec_func = partial(st_mp_idx, offst, ax, cro, in_fmt, out_fmt) idxs = range(idx0, idx0+offst) pool = mp.Pool(processes=mp.cpu_count()) # double check if names correspond - to remove later for nmi in idxs: #print(names[nmi-idx0], in_fmt.format(nmi)) if names[nmi-idx0] != in_fmt.format(nmi): print('Something wrong with file name format') continue #pool.map(exec_func, names[0:num_projs/2]) pool.map(exec_func, idxs) print("========== Done ==========") def main_360_mp_depth2(parameters): ctdirs, lvl0 = findCTdirs(parameters['360multi_input_dir'], glob_parameters.params['main_config_tomo_dir_name']) num_sets = len(ctdirs) if num_sets < 1: print(f"Didn't find any CT dirs in the input. Check directory structure and permissions. \n" f"Program expects to see a number of subdirectories in the input each of with \n" f"contains at least one directory with CT projections (currently name set to " f"{glob_parameters.params['main_config_tomo_dir_name']}. \n"+ f"The tif files in all " \ f" {glob_parameters.params['main_config_tomo_dir_name']}, " f" {glob_parameters.params['main_config_flats_dir_name']}, " f" {glob_parameters.params['main_config_darks_dir_name']} \n" f"subdirectories will be stitched to convert half-acquisition mode scans to ordinary \n" f"180-deg parallel-beam scans") return tmp = len(parameters['360multi_input_dir']) ctdirs_rel_paths = [] for i in range(num_sets): ctdirs_rel_paths.append(ctdirs[i][tmp + 1: len(ctdirs[i])]) print(f"Found the {num_sets} directories in the input with relative paths: {ctdirs_rel_paths}") # prepare axis and crop arrays dax = np.round(np.linspace(parameters['360multi_bottom_axis'], parameters['360multi_top_axis'], num_sets)) if parameters['360multi_manual_axis']: print(parameters['360multi_axis_dict']) dax = np.array(list(parameters['360multi_axis_dict'].values())) print(dax) cra = np.max(dax)-dax for i, ctdir in enumerate(ctdirs): print("================================================================") print(" -> Working On: " + str(ctdir)) print(f" axis position {dax[i]}, margin to crop {cra[i]} pixels") main_360_mp_depth1(ctdir, os.path.join(parameters['360multi_output_dir'], ctdirs_rel_paths[i]), dax[i], cra[i]) # print(ctdir, os.path.join(parameters['360multi_output_dir'], ctdirs_rel_paths[i]), dax[i], cra[i]) def clear_tmp(parameters): tmp_dirs = os.listdir(parameters['ezstitch_temp_dir']) for tmp_dir in tmp_dirs: shutil.rmtree(os.path.join(parameters['ezstitch_temp_dir'], tmp_dir)) def check_last_index(axis_list): """ Return the index of item in list immediately before first 'None' type :param axis_list: :return: the index of last non-None value """ last_index = 0 for index, item in enumerate(axis_list): if item == 'None': last_index = index - 1 return last_index last_index = index return last_index def complete_message(): print(" __.-/|") print(" \\`o_O'") print(" =( )= +-----+") print(" U| | FIN |") print(" /\\ /\\ / | +-----+") print(" ) /^\\) ^\\/ _)\\ |") print(" ) /^\\/ _) \\ |") print(" ) _ / / _) \\___|_") print(" /\\ )/\\/ || | )_)\\___,|))") print("< > |(,,) )__) |") print(" || / \\)___)\\") print(" | \\____( )___) )____") print(" \\______(_______;;;)__;;;)")tofu-0.12.0/tofu/ez/RR_external.py000066400000000000000000000150411423713721100167430ustar00rootroot00000000000000#!/usr/bin/env python2 """ Created on Aug 3, 2018 @author: SGasilov Initially it has been simplest median sorting Replaced by non-FFT based methods proposed by Nghia T. Vo and published in https://doi.org/10.1364/OE.26.028396 """ import os import argparse from tofu.util import read_image import numpy as np from tofu.util import get_filenames import multiprocessing as mp from functools import partial from scipy.ndimage import median_filter from scipy.ndimage import binary_dilation import tifffile def write_tiff(file_name, data): """ The default TIFF writer which uses :py:mod:`tifffile` module. Return the written file name. """ tifffile.imsave(file_name, data) return file_name def parse_args(): parser = argparse.ArgumentParser() parser.add_argument("--sinos", type=str, help="Input directory") parser.add_argument("--mws", type=int, help="Window size for small rings (sorting algorithm)") parser.add_argument("--mws2", type=int, help="Window size for large rings") parser.add_argument("--snr", type=int, help="Median window size along columns") parser.add_argument("--sort_only", type=int, help="Only sorting or both") return parser.parse_args() def RR_wide_sort(mws, mws2, snr, odir, fname): filt_sin_name = os.path.join(odir, os.path.split(fname)[1]) im = read_image(fname).astype(np.float32) im = remove_large_stripe(im, snr, mws2) im = remove_stripe_based_sorting(im, mws) write_tiff(filt_sin_name, im.astype(np.float32)) def RR_sort(mws, odir, fname): filt_sin_name = os.path.join(odir, os.path.split(fname)[1]) write_tiff( filt_sin_name, remove_stripe_based_sorting(read_image(fname).astype(np.float32), mws).astype(np.float32), ) def remove_stripe_based_sorting(sinogram, size, dim=1): # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396 """ Remove stripe artifacts in a sinogram using the sorting technique, algorithm 3 in Ref. [1]. Angular direction is along the axis 0. Parameters ---------- sinogram : array_like 2D array. Sinogram image. size : int Window size of the median filter. dim : {1, 2}, optional Dimension of the window. """ sinogram = np.transpose(sinogram) (nrow, ncol) = sinogram.shape list_index = np.arange(0.0, ncol, 1.0) mat_index = np.tile(list_index, (nrow, 1)) mat_comb = np.asarray(np.dstack((mat_index, sinogram))) mat_sort = np.asarray([row[row[:, 1].argsort()] for row in mat_comb]) if dim == 2: mat_sort[:, :, 1] = median_filter(mat_sort[:, :, 1], (size, size)) else: mat_sort[:, :, 1] = median_filter(mat_sort[:, :, 1], (size, 1)) mat_sort_back = np.asarray([row[row[:, 0].argsort()] for row in mat_sort]) return np.transpose(mat_sort_back[:, :, 1]) def detect_stripe(list_data, snr): # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396 """ Locate stripe positions using Algorithm 4 in Ref. [1]. Parameters ---------- list_data : array_like 1D array. Normalized data. snr : float Ratio used to segment stripes from background noise. """ npoint = len(list_data) list_sort = np.sort(list_data) listx = np.arange(0, npoint, 1.0) ndrop = np.int16(0.25 * npoint) (slope, intercept) = np.polyfit(listx[ndrop : -ndrop - 1], list_sort[ndrop : -ndrop - 1], 1) y_end = intercept + slope * listx[-1] noise_level = np.abs(y_end - intercept) noise_level = np.clip(noise_level, 1e-6, None) val1 = np.abs(list_sort[-1] - y_end) / noise_level val2 = np.abs(intercept - list_sort[0]) / noise_level list_mask = np.zeros(npoint, dtype=np.float32) if val1 >= snr: upper_thresh = y_end + noise_level * snr * 0.5 list_mask[list_data > upper_thresh] = 1.0 if val2 >= snr: lower_thresh = intercept - noise_level * snr * 0.5 list_mask[list_data <= lower_thresh] = 1.0 return list_mask def remove_large_stripe(sinogram, size, snr=3, drop_ratio=0.1, norm=True): # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396 """ Remove large stripes, algorithm 5 in Ref. [1], by: locating stripes, normalizing to remove full stripes, and using the sorting technique (Ref. [1]) to remove partial stripes. Angular direction is along the axis 0. Parameters ---------- sinogram : array_like 2D array. Sinogram image snr : float Ratio used to segment stripes from background noise. size : int Window size of the median filter. drop_ratio : float, optional Ratio of pixels to be dropped, which is used to reduce the false detection of stripes. norm : bool, optional Apply normalization if True. """ sinogram = np.copy(sinogram) # Make it mutable drop_ratio = np.clip(drop_ratio, 0.0, 0.8) (nrow, ncol) = sinogram.shape ndrop = int(0.5 * drop_ratio * nrow) sino_sort = np.sort(sinogram, axis=0) sino_smooth = median_filter(sino_sort, (1, size)) list1 = np.mean(sino_sort[ndrop : nrow - ndrop], axis=0) list2 = np.mean(sino_smooth[ndrop : nrow - ndrop], axis=0) list_fact = np.divide(list1, list2, out=np.ones_like(list1), where=list2 != 0) list_mask = detect_stripe(list_fact, snr) list_mask = np.float32(binary_dilation(list_mask, iterations=1)) mat_fact = np.tile(list_fact, (nrow, 1)) if norm is True: sinogram = sinogram / mat_fact # Normalization sino_tran = np.transpose(sinogram) list_index = np.arange(0.0, nrow, 1.0) mat_index = np.tile(list_index, (ncol, 1)) mat_comb = np.asarray(np.dstack((mat_index, sino_tran))) mat_sort = np.asarray([row[row[:, 1].argsort()] for row in mat_comb]) mat_sort[:, :, 1] = np.transpose(sino_smooth) mat_sort_back = np.asarray([row[row[:, 0].argsort()] for row in mat_sort]) sino_cor = np.transpose(mat_sort_back[:, :, 1]) listx_miss = np.where(list_mask > 0.0)[0] sinogram[:, listx_miss] = sino_cor[:, listx_miss] return sinogram def main(): args = parse_args() sinos = get_filenames(os.path.join(args.sinos, "*.tif")) # create output directory wdir = os.path.split(args.sinos)[0] odir = os.path.join(wdir, "sinos-filt") if not os.path.exists(odir): os.makedirs(odir) pool = mp.Pool(processes=mp.cpu_count()) if args.sort_only: exec_func = partial(RR_sort, args.mws, odir) else: exec_func = partial(RR_wide_sort, args.mws, args.mws2, args.snr, odir) pool.map(exec_func, sinos) if __name__ == "__main__": main() tofu-0.12.0/tofu/ez/__init__.py000066400000000000000000000000011423713721100162430ustar00rootroot00000000000000 tofu-0.12.0/tofu/ez/ctdir_walker.py000066400000000000000000000160061423713721100171720ustar00rootroot00000000000000""" Created on Apr 5, 2018 @author: gasilos """ import os class WalkCTdirs: """ Walks in the directory structure and creates list of paths to CT folders Determines flats before/after and checks that folders contain only tiff files fdt_names = flats/darks/tomo directory names """ def __init__(self, inpath, fdt_names, args, verb=True): self.lvl0 = os.path.abspath(inpath) self.ctdirs = [] self.types = [] self.ctsets = [] self.typ = [] self.total = 0 self.good = 0 self.verb = verb self._fdt_names = fdt_names self.common_flats = args.main_config_flats_path self.common_darks = args.main_config_darks_path self.common_flats2 = args.main_config_flats2_path self.use_common_flats2 = args.main_config_flats2_checkbox def print_tree(self): print("We start in {}".format(self.lvl0)) def findCTdirs(self): """ Walks directories rooted at "Input Directory" location Appends their absolute path to ctdir if they contain a directory with same name as "tomo" entry in GUI """ for root, dirs, files in os.walk(self.lvl0): for name in dirs: if name == self._fdt_names[2]: self.ctdirs.append(root) self.ctdirs = list(set(self.ctdirs)) def checkCTdirs(self): """ Determine whether directory is of type 3 or type 4 and store in self.typ with index corresponding to ctdir Type3: Has flats, darks and not flats2 -- or flats==flats2 Type4: Has flats, darks and flats2 """ for ctdir in self.ctdirs: # flats/darks and no flats2 or flats2==flats -> type 3 if ( os.path.exists(os.path.join(ctdir, self._fdt_names[1])) and os.path.exists(os.path.join(ctdir, self._fdt_names[0])) and ( not os.path.exists(os.path.join(ctdir, self._fdt_names[3])) or self._fdt_names[1] == self._fdt_names[3] ) ): self.typ.append(3) # flats/darks/flats2 -> type4 elif ( os.path.exists(os.path.join(ctdir, self._fdt_names[1])) and os.path.exists(os.path.join(ctdir, self._fdt_names[0])) and os.path.exists(os.path.join(ctdir, self._fdt_names[3])) ): self.typ.append(4) else: print(os.path.basename(ctdir)) self.typ.append(0) def checkCommonFDT(self): """ Verifies that paths to directories specified by common_flats, common_darks, and common_flats2 exist :return: True if directories exist, False if they do not exist """ for ctdir in self.ctdirs: if self.use_common_flats2 is True: self.typ.append(4) elif self.use_common_flats2 is False: self.typ.append(3) if self.use_common_flats2 is True: if ( os.path.exists(self.common_flats) and os.path.exists(self.common_darks) and os.path.exists(self.common_flats2) ): return True elif self.use_common_flats2 is False: if (os.path.exists(self.common_flats) and os.path.exists(self.common_darks)): return True return False def checkCommonFDTFiles(self): """ Verifies that directories of tomo and common flats/darks/flats contain only .tif files :return: True if directories exist, False if they do not exist """ for i, ctdir in enumerate(self.ctdirs): ctdir_tomo_path = os.path.join(ctdir, self._fdt_names[2]) if not self._checkTifs(ctdir_tomo_path): print("Invalid files found in " + str(ctdir_tomo_path)) self.typ[i] = 0 return False if not self._checkTifs(self.common_flats): print("Invalid files found in " + str(self.common_flats)) return False if not self._checkTifs(self.common_darks): print("Invalid files found in " + str(self.common_darks)) return False if self.use_common_flats2 and not self._checkTifs(self.common_flats2): print("Invalid files found in " + str(self.common_flats2)) return False return True def checkCTfiles(self): """ Checks whether each ctdir is of type 3 or 4 by comparing index of self.typ[] to corresponding index of ctdir[] Then for each directory of type 3 or 4 it checks sub-directories contain only .tif files If it contains invalid data then typ[] is set to 0 for corresponding index location """ for i, ctdir in enumerate(self.ctdirs): if ( self.typ[i] == 3 and self._checkTifs(os.path.join(ctdir, self._fdt_names[1])) and self._checkTifs(os.path.join(ctdir, self._fdt_names[0])) and self._checkTifs(os.path.join(ctdir, self._fdt_names[2])) ): continue elif ( self.typ[i] == 4 and self._checkTifs(os.path.join(ctdir, self._fdt_names[1])) and self._checkTifs(os.path.join(ctdir, self._fdt_names[0])) and self._checkTifs(os.path.join(ctdir, self._fdt_names[2])) and self._checkTifs(os.path.join(ctdir, self._fdt_names[3])) ): continue else: self.typ[i] = 0 def _checkTifs(self, tmpath): """ Checks each whether item in directory tmppath is a .tif file :param tmpath: Path to directory :return: 0 if invalid item found in directory - 1 if no invalid items found in directory """ for i in os.listdir(tmpath): if os.path.isdir(i): return 0 if i.split(".")[-1] != "tif": return 0 return 1 def SortBadGoodSets(self): """ Reduces type of all directories to either Good with flats 2 (1) or good without flats2 (0) or bad (<0) """ self.total = len(self.ctdirs) self.ctsets = sorted(zip(self.ctdirs, self.typ), key=lambda s: s[0]) self.total = len(self.ctsets) self.good = [int(y) > 2 for x, y in self.ctsets].count(True) tmp = len(self.lvl0) if self.verb: print("Total folders {}, good folders {}".format(self.total, self.good)) print("{:>20}\t{}".format("Path to CT set", "Typ: 0 bad, 3 no flats2, 4 with flats2")) for ctdir in self.ctsets: msg1 = ctdir[0][tmp:] if msg1 == "": msg1 = "." print("{:>20}\t{}".format(msg1, ctdir[1])) # keep paths to directories with good ct data only: self.ctsets = [q for q in self.ctsets if int(q[1] > 0)] def Getlvl0(self): return self.lvl0 tofu-0.12.0/tofu/ez/evaluate_sharpness.py000066400000000000000000000317231423713721100204170ustar00rootroot00000000000000import argparse import glob import multiprocessing import os import time import numpy as np from functools import partial from tofu.util import read_image from scipy.stats import skew, kurtosis def sum_abs_gradient(data): """Sum of absolute gradients.""" return np.sum(np.abs(np.gradient(data))) def mad(data): """Median absolute deviation.""" return np.median(np.abs(data - np.median(data))) def abs_sum(data): """Sum of the absolute values.""" return np.sum(np.abs(data)) def entropy(data, bins=256): """Image entropy.""" hist, bins = np.histogram(data, bins=bins) hist = hist.astype(np.float) hist /= hist.sum() valid = np.where(hist > 0) return -np.sum(np.dot(hist[valid], np.log2(hist[valid]))) def inverted(func, *args, **kwargs): """Return -func(*args, **kwargs).""" return -func(*args, **kwargs) def filter_data(data, fwhm=32.0): """Filter low frequencies in 1D *data* (needed when the axis is far away by axis evaluation). *fwhm* is the FWHM of the gaussian used to filter out low frequencies in real space. The window is then computed as fft(1 - gauss). """ mean = np.mean(data) sigma = fwhm / (2 * np.sqrt(2 * np.log(2))) # We compute the gaussian in Fourier space, so convert sigma first f_sigma = 1.0 / (2 * np.pi * sigma) x = np.fft.fftfreq(len(data)) fltr = 1 - np.exp(-(x ** 2) / (2 * f_sigma ** 2)) return np.fft.ifft(np.fft.fft(data) * fltr).real + mean METRICS_1D = { "mean": np.mean, "std": np.std, "skew": skew, "kurtosis": kurtosis, "mad": mad, "asum": abs_sum, "min": np.min, "max": np.max, "entropy": entropy, } METRICS_2D = {"sag": sum_abs_gradient} for key in list(METRICS_1D): METRICS_1D["m" + key] = partial(inverted, METRICS_1D[key]) for key in list(METRICS_2D): METRICS_2D["m" + key] = partial(inverted, METRICS_2D[key]) # for key in METRICS_1D.keys(): # METRICS_1D['m' + key] = partial(inverted, METRICS_1D[key]) # for key in METRICS_2D.keys(): # METRICS_2D['m' + key] = partial(inverted, METRICS_2D[key]) def evaluate( image, metrics_1d=None, metrics_2d=None, global_min=None, global_max=None, metrics_1d_kwargs=None, blur_fwhm=None, ): """Evaluate *metrics_1d* which work on a flattened image and *metrics_2d* in an *image* which can either be a file path or an imageIf the metrics are None all the default ones are used. *global_min* and *global_max* are the mean extrema of the whole sequence used to cut off outlier values. Extrema are used only by 1d metrics. *metrics_1d_kwargs* are additional keyword arguments passed to the functions, they are specified in dictioinary {func_name: kwargs}. """ if metrics_1d is None: metrics_1d = METRICS_1D if metrics_2d is None: metrics_2d = METRICS_2D results = {} if type(image) == str: image = read_image(image) if blur_fwhm: from scipy.ndimage import gaussian_filter image = gaussian_filter(image, blur_fwhm / (2 * np.sqrt(2 * np.log(2)))) if global_min is None or global_max is None: flattened = image.flatten() else: # Use global cutoff flattened = image[np.where((image >= global_min) & (image <= global_max))] if metrics_1d is not None: for metric in metrics_1d: kwargs = {} if metrics_1d_kwargs and metric in metrics_1d_kwargs: kwargs = metrics_1d_kwargs[metric] results[metric] = metrics_1d[metric](flattened, **kwargs) if metrics_2d is not None: for metric in metrics_2d: results[metric] = metrics_2d[metric](image) return results def evaluate_metrics(images, out_prefix, *args, **kwargs): """Evaluate many *images* which are either file paths or images. *out_prefix* is the metric results file prefix. Metric names and file extension are appended to it. *args* and *kwargs* are passed to :func:`evaluate`. Except for *fwhm* in *kwargs* which is used to filter low frequencies from the results. """ fwhm = kwargs.pop("fwhm") if "fwhm" in kwargs else None pool = multiprocessing.Pool(processes=multiprocessing.cpu_count()) exec_func = partial(evaluate, *args, **kwargs) results = pool.map(exec_func, images) merged = {} for metric in results[0].keys(): merged[metric] = np.array([result[metric] for result in results]) if fwhm: # Filter out low frequencies merged[metric] = filter_data(merged[metric], fwhm=fwhm) if out_prefix is not None: path = out_prefix + "_" + metric + ".txt" np.savetxt(path, merged[metric], fmt="%g") return merged def process( names, num_images_for_stats=0, metric_names=None, out_prefix=None, fwhm=None, metrcs_1d_kwargs=None, blur_fwhm=None, ): """Process many files given by *names*. *out_prefix* is the output file prefix where the metric results will be written to. *fwhm* is used to filter our low frequencies from the results. *metrics_1d_kwargs* are additional keyword arguments passed to the functions, they are specified in dictioinary {func_name: kwargs}. """ if num_images_for_stats: if num_images_for_stats == -1: num_images_for_stats = len(names) extrema_metrics = {"min": np.min, "max": np.max} extrema = evaluate_metrics( names[:num_images_for_stats], None, metrics_1d=extrema_metrics, fwhm=fwhm, blur_fwhm=blur_fwhm, ) global_min = np.mean(extrema["min"]) global_max = np.mean(extrema["max"]) else: global_min = global_max = None metrics_1d, metrics_2d = make_metrics(metric_names) return evaluate_metrics( names, out_prefix, metrics_1d=metrics_1d, metrics_2d=metrics_2d, global_min=global_min, global_max=global_max, fwhm=fwhm, metrics_1d_kwargs=metrcs_1d_kwargs, blur_fwhm=blur_fwhm, ) def main(): args = parse_args() names = sorted(glob.glob(args.input)) if args.dims == 2: axis_length = int(np.sqrt(len(names))) size_str = "{} x {}".format(axis_length, axis_length) else: axis_length = len(names) size_str = str(axis_length) print("Data size: {}".format(size_str)) kwargs = {"entropy": {"bins": args.entropy_num_bins}} for key in kwargs.keys(): kwargs["m" + key] = kwargs[key] st = time.time() results = process( names, num_images_for_stats=args.num_images_for_stats, metric_names=args.metrics, fwhm=args.fwhm, metrcs_1d_kwargs=kwargs, blur_fwhm=args.blur_fwhm, ) if args.verbose: print("Duration: {} s".format(time.time() - st)) x_data = y_data = None for metric, data in results.iteritems(): if x_data is None: x_data = construct_range(args.x_from, args.x_to, len(data), unit=args.x_unit) y_data = construct_range(args.y_from, args.y_to, len(data), unit=args.y_unit) write( args.output, metric, data, axis_length, x_data=x_data, y_data=y_data, save_raw=args.save_raw, save_txt=args.save_txt, save_plot=args.save_plot, ) argmax = np.argmax(data) if args.dims == 2: argmax = np.unravel_index(argmax, (axis_length, axis_length)) y_argmax = y_data[argmax[0]].magnitude x_argmax = x_data[argmax[1]].magnitude retval = (x_argmax, y_argmax) else: x_argmax = x_data[argmax].magnitude retval = x_argmax print(retval) def write( out_dir, metric, data, axis_length, x_data=None, y_data=None, save_raw=False, save_txt=False, save_plot=False, ): out_path = os.path.join(out_dir, metric) if not os.path.exists(out_dir): os.makedirs(out_dir, mode=0o755) if axis_length == len(data): # 1D if save_raw: np.save(out_path + ".npy", data) if save_plot: write_1d_plot(out_path, data, metric, x_data=x_data) else: reshaped = data.reshape(axis_length, axis_length) if save_raw: write_libtiff(out_path + "_raw" + ".tif", reshaped.astype(np.float32)) if save_plot: write_2d_plot(out_path, reshaped, metric, x_data=x_data, y_data=y_data) if save_txt: data = np.array((x_data.magnitude, data)) # Convenient to be read by pgfplots np.savetxt(out_path + ".txt", data.T, fmt="%g", delimiter="\t", comments="", header="x\ty") def write_1d_plot(out_path, data, metric, x_data=None): from matplotlib import pyplot as plt plt.figure() if x_data is not None: plt.plot(x_data.magnitude, data) plt.xlabel(x_data.units) else: plt.plot(data) plt.title(metric) plt.grid() plt.savefig(out_path + ".tif") plt.close() def write_2d_plot(out_path, data, metric, x_data=None, y_data=None): from matplotlib import pyplot as plt, cm plt.figure() plt.imshow(data, cmap=cm.gray) if x_data is not None: x_from = x_data[0].magnitude x_to = x_data[-1].magnitude num_x_ticks = min(data.shape[1], 9) x_locs = np.linspace(-0.5, data.shape[1] - 0.5, num_x_ticks) x_labels = np.linspace(x_from, x_to, num_x_ticks) plt.xticks(x_locs, x_labels) plt.xlabel(x_data.units) if y_data is not None: y_from = y_data[0].magnitude y_to = y_data[-1].magnitude num_y_ticks = min(data.shape[0], 9) y_locs = np.linspace(-0.5, data.shape[0] - 0.5, num_y_ticks) y_labels = np.linspace(y_from, y_to, num_y_ticks) plt.yticks(y_locs, y_labels) plt.ylabel(y_data.units) plt.title(metric) plt.savefig(out_path + ".tif") plt.close() def construct_range(start, stop, num, unit=""): start = 0 if start is None else start stop = num if stop is None else stop region = np.linspace(start, stop, num=num, endpoint=False) return q.Quantity(region, unit) def make_metrics(keys): """Buld 1d and 2d metrics dictionaries from *keys*.""" if keys is None: metrics_1d = METRICS_1D metrics_2d = METRICS_2D else: metrics_1d = {key: METRICS_1D[key] for key in keys if key in METRICS_1D} metrics_2d = {key: METRICS_2D[key] for key in keys if key in METRICS_2D} return metrics_1d, metrics_2d def parse_args(): parser = argparse.ArgumentParser( description="Evaluate sharpness metrics based on parameter changes in 3D reconstruction" ) parser.add_argument("input", type=str, help="Input path pattern") parser.add_argument( "dims", type=int, choices=(1, 2), help="Number of scanned parameters in the data set" ) parser.add_argument("--output", type=str, default=".", help="Output directory") parser.add_argument( "--metrics", type=str, nargs="*", choices=METRICS_1D.keys() + METRICS_2D.keys(), help="Metrics to determine (m prefix means -metric)", ) parser.add_argument("--x-from", type=float, help="X data from") parser.add_argument("--x-to", type=float, help="X data to") parser.add_argument("--x-unit", type=str, default="", help="X axis units") parser.add_argument("--y-from", type=float, help="Y data from") parser.add_argument("--y-to", type=float, help="Y data to") parser.add_argument("--y-unit", type=str, default="", help="Y axis units") parser.add_argument( "--num-images-for-stats", type=int, default=0, help=( "If not zero, an " "image sequence is first read and the mean min and max intensities are " "used as a global range of values to work on (-1 means read all images)" ), ) parser.add_argument( "--fwhm", type=float, help="FWHM of 1 - Gauss in real space used to filter out low frequencies.", ) parser.add_argument( "--entropy-num-bins", type=int, default=256, help="Number of bins to use for histogram calculation by entropy", ) parser.add_argument( "--blur-fwhm", type=float, help="FWHM of the Gaussian blur applied to images" ) parser.add_argument("--save-raw", action="store_true", help="Store raw data (1D npy, 2D tiff)?") parser.add_argument("--save-txt", action="store_true", help="Store raw data as text files") parser.add_argument("--save-plot", action="store_true", help="Store plot data") parser.add_argument("--verbose", action="store_true", help="Verbose output") args = parser.parse_args() if (args.x_from is None) ^ (args.x_to is None): raise ValueError("Either both x-from and x-to are set or both are not") if (args.y_from is None) ^ (args.y_to is None): raise ValueError("Either both y-from and y-to are set or both are not") return args if __name__ == "__main__": main() tofu-0.12.0/tofu/ez/find_axis_cmd_gen.py000066400000000000000000000141311423713721100201350ustar00rootroot00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import glob import os import numpy as np from tofu.ez.evaluate_sharpness import process as process_metrics from tofu.ez.util import enquote from tofu.util import get_filenames, read_image, determine_shape import tifffile class findCOR_cmds(object): """ Generates commands to find the axis of rotation """ def __init__(self, fol): self._fdt_names = fol def make_inpaths(self, lvl0, flats2, args): """ Creates a list of paths to flats/darks/tomo directories :param lvl0: Root of directory containing flats/darks/tomo :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2 :return: List of paths to the directories containing darks/flats/tomo and flats2 (if used) """ indir = [] # If using flats/darks/flats2 in same dir as tomo if not args.main_config_common_flats_darks: for i in self._fdt_names[:3]: indir.append(os.path.join(lvl0, i)) if flats2 - 3: indir.append(os.path.join(lvl0, self._fdt_names[3])) return indir # If using common flats/darks/flats2 across multiple reconstructions elif args.main_config_common_flats_darks: indir.append(args.main_config_darks_path) indir.append(args.main_config_flats_path) indir.append(os.path.join(lvl0, self._fdt_names[2])) if args.main_config_flats2_checkbox: indir.append(args.main_config_flats2_path) return indir def find_axis_std(self, ctset, tmpdir, ax_range, p_width, search_row, nviews, args, WH): indir = self.make_inpaths(ctset[0], ctset[1], args) image = read_image(get_filenames(indir[2])[0]) cmd = "tofu reco --absorptivity --fix-nan-and-inf --overall-angle 180 --axis-angle-x 0" cmd += " --darks {} --flats {} --projections {}".format( indir[0], indir[1], enquote(indir[2]) ) cmd += " --number {}".format(nviews) if ctset[1] == 4: cmd += " --flats2 {}".format(indir[3]) out_pattern = os.path.join(tmpdir, "axis-search/sli") cmd += " --output {}".format(enquote(out_pattern)) cmd += " --x-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1) cmd += " --y-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1) image_height = WH[0] ax_range_list = ax_range.split(",") range_min = ax_range_list[0] range_max = ax_range_list[1] step = ax_range_list[2] range_string = str(range_min) + "," + str(range_max) + "," + str(step) cmd += " --region={}".format(range_string) res = [float(num) for num in ax_range.split(",")] cmd += " --output-bytes-per-file 0" cmd += ' --z-parameter center-position-x' cmd += ' --z {}'.format(args.main_cor_search_row_start - int(image_height/2)) print(cmd) os.system(cmd) points, maximum = evaluate_images_simp(out_pattern + "*.tif", "msag") return res[0] + res[2] * maximum def find_axis_corr(self, ctset, vcrop, y, height, multipage, args): indir = self.make_inpaths(ctset[0], ctset[1], args) """Use correlation to estimate center of rotation for tomography.""" from scipy.signal import fftconvolve def flat_correct(flat, radio): nonzero = np.where(radio != 0) result = np.zeros_like(radio) result[nonzero] = flat[nonzero] / radio[nonzero] # log(1) = 0 result[result <= 0] = 1 return np.log(result) if multipage: with tifffile.TiffFile(get_filenames(indir[2])[0]) as tif: first = tif.pages[0].asarray().astype(np.float) with tifffile.TiffFile(get_filenames(indir[2])[-1]) as tif: last = tif.pages[-1].asarray().astype(np.float) with tifffile.TiffFile(get_filenames(indir[0])[-1]) as tif: dark = tif.pages[-1].asarray().astype(np.float) with tifffile.TiffFile(get_filenames(indir[1])[0]) as tif: flat1 = tif.pages[-1].asarray().astype(np.float) - dark else: first = read_image(get_filenames(indir[2])[0]).astype(np.float) last = read_image(get_filenames(indir[2])[-1]).astype(np.float) dark = read_image(get_filenames(indir[0])[-1]).astype(np.float) flat1 = read_image(get_filenames(indir[1])[-1]) - dark first = flat_correct(flat1, first - dark) if ctset[1] == 4: if multipage: with tifffile.TiffFile(get_filenames(indir[3])[0]) as tif: flat2 = tif.pages[-1].asarray().astype(np.float) - dark else: flat2 = read_image(get_filenames(indir[3])[-1]) - dark last = flat_correct(flat2, last - dark) else: last = flat_correct(flat1, last - dark) if vcrop: y_region = slice(y, min(y + height, first.shape[0]), 1) first = first[y_region, :] last = last[y_region, :] width = first.shape[1] first = first - first.mean() last = last - last.mean() conv = fftconvolve(first, last[::-1, :], mode="same") center = np.unravel_index(conv.argmax(), conv.shape)[1] return (width / 2.0 + center) / 2.0 # Find midpoint width of image and return its value def find_axis_image_midpoint(self, ctset, multipage, height_width): return height_width[1] // 2 def evaluate_images_simp( input_pattern, metric, num_images_for_stats=0, out_prefix=None, fwhm=None, blur_fwhm=None, verbose=False, ): # simplified version of original evaluate_images function # from Tomas's optimize_parameters script names = sorted(glob.glob(input_pattern)) res = process_metrics( names, num_images_for_stats=num_images_for_stats, metric_names=(metric,), out_prefix=out_prefix, fwhm=fwhm, blur_fwhm=blur_fwhm, )[metric] return res, np.argmax(res) tofu-0.12.0/tofu/ez/image_read_write.py000066400000000000000000000163341423713721100200130ustar00rootroot00000000000000import os, glob import numpy as np import tifffile from tifffile import imread, imwrite class InvalidDataSetError(Exception): """ Error to be raised on attempt to read data from empty or non-existing data set """ def validate_files_path(files_path: str, supported_file_types: list) -> bool: """ Validates specified path :param supported_file_types: List of supported extensions :param files_path: Path to validate :return: True if path exists and contains at least one file of supported type, else False """ try: valid_files_list = get_valid_files_list( files_path=files_path, supported_file_types=supported_file_types ) except InvalidDataSetError: return False return len(valid_files_list) > 0 def get_valid_files_list(files_path: str, supported_file_types: list) -> list: """ Get the list of files of supported type in directory :param supported_file_types: List of supported extensions :param files_path: Path to directory with files :return: List of full paths to files """ # Check if directory exists if not os.path.exists(files_path): raise InvalidDataSetError(f"No such directory: {files_path}") files_list = os.listdir(files_path) valid_files_list = [ os.path.join(files_path, file_name) for file_name in files_list if os.path.splitext(file_name)[1] in supported_file_types ] return sorted(valid_files_list) def read_image(image_file_path: str, data_type=np.float32) -> np.ndarray: """ Reads image file to numpy.ndarray of specified type :param data_type: Data type to store the image :param image_file_path: Full path to image to read :return: """ return imread(image_file_path).astype(dtype=data_type) def write_image(image: np.ndarray, target_directory: str, target_name: str, data_type=np.float32): """ Writes image data to file :param image: Image data :param target_directory: Path to directory to write image :param target_name: Target image file name :param data_type: Data type to write the image :return: """ os.makedirs(target_directory, exist_ok=True) data_file_path = os.path.join(target_directory, target_name) imwrite(data_file_path, data=image.astype(dtype=data_type)) def write_all_images(tiff_arr: np.ndarray, target_directory: str, data_type=np.float32): """ Writes all images in numpy array as individual files in a directory :param tiff_arr: Array containing images :param target_directory: Path to directory to write images :param data_type: Data type to write the images :return: """ print("Writing Images to Directory") # We determine the number of leading zeros to append. # Find number of digits from number of files to write, then add +1 number of leading zeros index = 1 length_str = str(tiff_arr.shape[0]) num_digits = len(length_str) for image in tiff_arr: write_image( image, target_directory, "Image_" + str(index).zfill(num_digits + 1) + ".tif", data_type ) index += 1 print("Finished Writing Images to Directory") def read_all_images( image_files_path: str, supported_image_types: list, data_type=np.float32 ) -> np.ndarray: """ Reads all images of the supported type from specified directory :param supported_image_types: List of supported extensions :param image_files_path: Path to directory with images :param data_type: Data type to store the images :return: 3-dimensional numpy.ndarray of specified type, first index being image index """ valid_files_list = get_valid_files_list( files_path=image_files_path, supported_file_types=supported_image_types ) if len(valid_files_list) == 0: raise InvalidDataSetError( f"Directory {image_files_path} " f"does not contain files of supported types {supported_image_types}" ) data_array = imread(valid_files_list).astype(dtype=data_type) return np.array(data_array) """Image readers for convenient work with multi-page image sequences.""" """ TAKEN STRAIT FROM ufo-kit/Concert (python 2 version) with permission of Tomas""" class FileSequenceReader(object): """Image sequence reader optimized for reading consecutive images. One multi-page image file is not closed after an image is read so that it does not have to be re-opened for reading the next image. The :func:`.close` function must be called explicitly in order to close the last opened image. """ def __init__(self, file_prefix, ext=''): if os.path.isdir(file_prefix): file_prefix = os.path.join(file_prefix, '*' + ext) self._filenames = sorted(glob.glob(file_prefix)) if not self._filenames: raise SequenceReaderError("No files matching `{}' found".format(file_prefix)) self._lengths = {} self._file = None self._filename = None def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): self.close() @property def num_images(self): num = 0 for filename in self._filenames: num += self._get_num_images_in_file(filename) return num def read(self, index): if index < 0: # Enables negative indexing index += self.num_images file_index = 0 while index >= 0: if file_index >= len(self._filenames): raise SequenceReaderError('image index greater than sequence length') index -= self._get_num_images_in_file(self._filenames[file_index]) file_index += 1 file_index -= 1 index += self._lengths[self._filenames[file_index]] self._open(self._filenames[file_index]) return self._read_real(index) def _open(self, filename): if self._filename != filename: if self._filename: self.close() self._file = self._open_real(filename) self._filename = filename def close(self): if self._filename: self._close_real() self._file = None self._filename = None def _get_num_images_in_file(self, filename): if filename not in self._lengths: self._open(filename) self._lengths[filename] = self._get_num_images_in_file_real() return self._lengths[filename] def _open_real(self, filename): """Returns an open file.""" raise NotImplementedError def _close_real(self, filename): raise NotImplementedError def _get_num_images_in_file_real(self): raise NotImplementedError def _read_real(self, index): raise NotImplementedError class TiffSequenceReader(FileSequenceReader): def __init__(self, file_prefix, ext='.tif'): super(TiffSequenceReader, self).__init__(file_prefix, ext=ext) def _open_real(self, filename): import tifffile return tifffile.TiffFile(filename) def _close_real(self): self._file.close() def _get_num_images_in_file_real(self): return len(self._file.pages) def _read_real(self, index): return self._file.pages[index].asarray() class SequenceReaderError(Exception): passtofu-0.12.0/tofu/ez/main.py000066400000000000000000000403611423713721100154450ustar00rootroot00000000000000""" Created on Apr 5, 2018 @author: sergei gasilov """ import logging import os from tofu.util import get_filenames, read_image import warnings warnings.filterwarnings("ignore") import time #from shutil import rmtree from tofu.ez.ctdir_walker import WalkCTdirs from tofu.ez.tofu_cmd_gen import tofu_cmds from tofu.ez.ufo_cmd_gen import ufo_cmds from tofu.ez.find_axis_cmd_gen import findCOR_cmds from tofu.ez.util import * # from tofu.util import get_filenames LOG = logging.getLogger(__name__) def get_CTdirs_list(inpath, fdt_names, args): """ Determines whether directories containing CT data are valid. Returns list of subdirectories with valid CT data :param inpath: Path to the CT directory containing subdirectories with flats/darks/tomo (and flats2 if used) :param fdt_names: Names of the directories which store flats/darks/tomo (and flats2 if used) :param args: Arguments from the GUI :return: W.ctsets: List of "good" CTSets and W.lvl0: Path to root of CT sets """ # Constructor call to create WalkCTDirs object W = WalkCTdirs(inpath, fdt_names, args) # Find any directories containing "tomo" directory W.findCTdirs() # If "Use common flats/darks across multiple experiments" is enabled if args.main_config_common_flats_darks: logging.debug("Use common darks/flats") logging.debug("Path to darks: " + str(args.main_config_darks_path)) logging.debug("Path to flats: " + str(args.main_config_flats_path)) logging.debug("Path to flats2: " + str(args.main_config_flats2_path)) logging.debug("Use flats2: " + str(args.main_config_flats2_checkbox)) # Determine whether paths to common flats/darks/flats2 exist if not W.checkCommonFDT(): print("Invalid path to common flats/darks") return W.ctsets, W.lvl0 else: LOG.debug("Paths to common flats/darks exist") # Check whether directories contain only .tif files if not W.checkCommonFDTFiles(): return W.ctsets, W.lvl0 else: # Sort good bad sets W.SortBadGoodSets() return W.ctsets, W.lvl0 # If "Use common flats/darks across multiple experiments" is not enabled else: LOG.debug("Use flats/darks in same directory as tomo") # Check if common flats/darks/flats2 are type 3 or 4 W.checkCTdirs() # Need to check if common flats/darks contain only .tif files W.checkCTfiles() W.SortBadGoodSets() return W.ctsets, W.lvl0 def frmt_ufo_cmds(cmds, ctset, out_pattern, ax, args, Tofu, Ufo, FindCOR, nviews, WH): """formats list of processing commands for a CT set""" # two helper variables to mark that PR/FFC has been done at some step swiFFC = True # FFC is always required required swiPR = args.main_pr_phase_retrieval # PR is an optional operation ####### PREPROCESSING ######### flat_file_for_mask = os.path.join(args.main_config_temp_dir, 'flat.tif') if args.main_filters_remove_spots: if not args.main_config_common_flats_darks: flatdir = os.path.join(ctset[0], Tofu._fdt_names[1]) elif args.main_config_common_flats_darks: flatdir = args.main_config_flats_path cmd = make_copy_of_flat(flatdir, flat_file_for_mask, args.main_config_dry_run) cmds.append(cmd) if args.main_config_preprocess: cmds.append('echo " - Applying filter(s) to images "') cmds_prepro = Ufo.get_pre_cmd(ctset, args.main_config_preprocess_command, args.main_config_temp_dir, args.main_config_dry_run, args) cmds.extend(cmds_prepro) # reset location of input data ctset = (args.main_config_temp_dir, ctset[1]) ################################################### if args.main_filters_remove_spots: # generate commands to remove sci. spots from projections cmds.append('echo " - Flat-correcting and removing large spots"') cmds_inpaint = Ufo.get_inp_cmd(ctset, args.main_config_temp_dir, args, WH[0], nviews, flat_file_for_mask) # reset location of input data ctset = (args.main_config_temp_dir, ctset[1]) cmds.extend(cmds_inpaint) swiFFC = False # no need to do FFC anymore ######## PHASE-RETRIEVAL ####### # Do PR separately if sinograms must be generate or if vertical ROI is defined if args.main_pr_phase_retrieval and args.main_filters_ring_removal: # or (args.main_pr_phase_retrieval and args.main_region_select_rows): if swiFFC: # we still need need flat correction #Inpaint No cmds.append('echo " - Phase retrieval with flat-correction"') if args.advanced_ffc_sinFFC: cmds.append(Tofu.get_pr_sinFFC_cmd(ctset, args, nviews, WH[0])) cmds.append(Tofu.get_pr_tofu_cmd_sinFFC(ctset, args, nviews, WH)) elif not args.advanced_ffc_sinFFC: cmds.append(Tofu.get_pr_tofu_cmd(ctset, args, nviews, WH[0])) else: # Inpaint Yes cmds.append('echo " - Phase retrieval from flat-corrected projections"') cmds.extend(Ufo.get_pr_ufo_cmd(args, nviews, WH)) swiPR = False # no need to do PR anymore swiFFC = False # no need to do FFC anymore # if args.PR and args.vcrop: # have to reset location of input data # ctset = (args.tmpdir, ctset[1]) ################# RING REMOVAL ####################### if args.main_filters_ring_removal: # Generate sinograms first if swiFFC: # we still need to do flat-field correction if args.advanced_ffc_sinFFC: # Create flat corrected images using sinFFC cmds.append(Tofu.get_sinFFC_cmd(ctset, args, nviews, WH[0])) # Feed the flat corrected images to sino gram generation cmds.append(Tofu.get_sinos_noffc_cmd(ctset[0], args.main_config_temp_dir, args, nviews, WH)) elif not args.advanced_ffc_sinFFC: cmds.append('echo " - Make sinograms with flat-correction"') cmds.append(Tofu.get_sinos_ffc_cmd(ctset, args.main_config_temp_dir, args, nviews, WH)) else: # we do not need flat-field correction cmds.append('echo " - Make sinograms without flat-correction"') cmds.append(Tofu.get_sinos_noffc_cmd(ctset[0], args.main_config_temp_dir, args, nviews, WH)) swiFFC = False # Filter sinograms if args.main_filters_ring_removal_ufo_lpf: if args.main_filters_ring_removal_ufo_lpf_1d_or_2d: cmds.append('echo " - Ring removal - ufo 1d stripes filter"') cmds.append(Ufo.get_filter1d_sinos_cmd(args.main_config_temp_dir, args.main_filters_ring_removal_ufo_lpf_sigma_horizontal, nviews)) else: cmds.append('echo " - Ring removal - ufo 2d stripes filter"') cmds.append(Ufo.get_filter2d_sinos_cmd(args.main_config_temp_dir, \ args.main_filters_ring_removal_ufo_lpf_sigma_horizontal, args.main_filters_ring_removal_ufo_lpf_sigma_vertical, nviews, WH[1])) else: cmds.append('echo " - Ring removal - sarepy filter(s)"') # note - calling an external program, not an ufo-kit script tmp = os.path.dirname(os.path.abspath(__file__)) path_to_filt = os.path.join(tmp, "RR_external.py") if os.path.isfile(path_to_filt): tmp = os.path.join(args.main_config_temp_dir, "sinos") cmdtmp = 'python {} --sinos {} --mws {} --mws2 {} --snr {} --sort_only {}' \ .format(path_to_filt, tmp, args.main_filters_ring_removal_sarepy_window_size, args.main_filters_ring_removal_sarepy_window, args.main_filters_ring_removal_sarepy_SNR, int(not args.main_filters_ring_removal_sarepy_wide)) cmds.append(cmdtmp) else: cmds.append('echo "Omitting RR because file with filter does not exist"') if not args.main_config_keep_temp: cmds.append("rm -rf {}".format(os.path.join(args.main_config_temp_dir, "sinos"))) # Convert filtered sinograms back to projections cmds.append('echo " - Generating proj from filtered sinograms"') cmds.append(Tofu.get_sinos2proj_cmd(args, WH[0])) # reset location of input data ctset = (args.main_config_temp_dir, ctset[1]) # Finally - call to tofu reco cmds.append('echo " - CT with axis {}; ffc:{}, PR:{}"'.format(ax, swiFFC, swiPR)) if args.advanced_ffc_sinFFC and swiFFC: cmds.append(Tofu.get_sinFFC_cmd(ctset, args, nviews, WH[0])) cmds.append( Tofu.get_reco_cmd_sinFFC(ctset, out_pattern, ax, args, nviews, WH, swiFFC, swiPR) ) else: # If not using sinFFC cmds.append(Tofu.get_reco_cmd(ctset, out_pattern, ax, args, nviews, WH, swiFFC, swiPR)) return nviews, WH def fmt_nlmdn_ufo_cmd(inpath: str, outpath: str, args): """ :param inp: Path to input directory before NLMDN applied :param out: Path to output directory after NLMDN applied :param args: List of args :return: """ cmd = 'ufo-launch read path={}'.format(inpath) cmd += ' ! non-local-means patch-radius={}'.format(args.advanced_nlmdn_patch_radius) cmd += ' search-radius={}'.format(args.advanced_nlmdn_sim_search_radius) cmd += ' h={}'.format(args.advanced_nlmdn_smoothing_control) cmd += ' sigma={}'.format(args.advanced_nlmdn_noise_std) cmd += ' window={}'.format(args.advanced_nlmdn_window) cmd += ' fast={}'.format(args.advanced_nlmdn_fast) cmd += ' estimate-sigma={}'.format(args.advanced_nlmdn_estimate_sigma) cmd += ' ! write filename={}'.format(enquote(outpath)) if not args.advanced_nlmdn_save_bigtiff: cmd += " bytes-per-file=0 tiff-bigtiff=False" return cmd def execute_reconstruction(args, fdt_names): # array with the list of commands cmds = [] # clean temporary directory or create if it doesn't exist if not os.path.exists(args.main_config_temp_dir): os.makedirs(args.main_config_temp_dir) # else: # clean_tmp_dirs(args.main_config_temp_dir, fdt_names) if args.main_region_clip_histogram: if args.main_region_histogram_min > args.main_region_histogram_max: raise ValueError('hmin must be smaller than hmax to convert to 8bit without contrast inversion') # get list of all good CT directories to be reconstructed print('*********** Analyzing input directory ************') W, lvl0 = get_CTdirs_list(args.main_config_input_dir, fdt_names, args) # W is an array of tuples (path, type) # get list of already reconstructed sets recd_sets = findSlicesDirs(args.main_config_output_dir) # initialize command generators FindCOR = findCOR_cmds(fdt_names) Tofu = tofu_cmds(fdt_names) Ufo = ufo_cmds(fdt_names) # populate list of reconstruction commands print("*********** AXIS INFO ************") for i, ctset in enumerate(W): # ctset is a tuple containing a path and a type (3 or 4) if not already_recd(ctset[0], lvl0, recd_sets): # determine initial number of projections and their shape path2proj = os.path.join(ctset[0], fdt_names[2]) nviews, WH, multipage = get_dims(path2proj) # If args.main_cor_axis_search_method == 4 then bypass axis search and use image midpoint if args.main_cor_axis_search_method != 4: if (args.main_region_select_rows and bad_vert_ROI(multipage, path2proj, args.main_region_first_row, args.main_region_number_rows)): print('{}\t{}'.format('CTset:', ctset[0])) print('{:>30}\t{}'.format('Axis:', 'na')) print('Vertical ROI does not contain any rows.') print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, WH)) continue # Find axis of rotation using auto: correlate first/last projections if args.main_cor_axis_search_method == 1: ax = FindCOR.find_axis_corr(ctset, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, multipage, args) # Find axis of rotation using auto: minimize STD of a slice elif args.main_cor_axis_search_method == 2: cmds.append("echo \"Cleaning axis-search in tmp directory\"") os.system('rm -rf {}'.format(os.path.join(args.main_config_temp_dir, 'axis-search'))) ax = FindCOR.find_axis_std(ctset, args.main_config_temp_dir, args.main_cor_axis_search_interval, args.main_cor_recon_patch_size, args.main_cor_search_row_start, nviews, args, WH) else: ax = args.main_cor_axis_column + i * args.main_cor_axis_increment_step # If args.main_cor_axis_search_method == 4 then bypass axis search and use image midpoint elif args.main_cor_axis_search_method == 4: ax = FindCOR.find_axis_image_midpoint(ctset, multipage, WH) print("Bypassing axis search and using image midpoint: {}".format(ax)) setid = ctset[0][len(lvl0) + 1:] out_pattern = os.path.join(args.main_config_output_dir, setid, 'sli/sli') cmds.append('echo ">>>>> PROCESSING {}"'.format(setid)) # rm files in temporary directory first of all to # format paths correctly and to avoid problems # when reconstructing ct sets with variable number of rows or projections cmds.append('echo "Cleaning temporary directory"'.format(setid)) clean_tmp_dirs(args.main_config_temp_dir, fdt_names) # call function which formats commands for this data set nviews, WH = frmt_ufo_cmds(cmds, ctset, out_pattern, \ ax, args, Tofu, Ufo, FindCOR, nviews, WH) save_params(args, setid, ax, nviews, WH) print('{}\t{}'.format('CTset:', ctset[0])) print('{:>30}\t{}'.format('Axis:', ax)) print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, WH)) # tmp = "Number of projections: {}, dimensions: {}".format(nviews, WH) # cmds.append("echo \"{}\"".format(tmp)) if args.advanced_nlmdn_apply_after_reco: logging.debug("Using Non-Local Means Denoising") nlmdn_input = out_pattern head, tail = os.path.split(out_pattern) slidir = os.path.dirname(head) nlmdn_output = os.path.join(slidir+"-nlmdn", "sli-nlmdn-%04i.tif") cmds.append(fmt_nlmdn_ufo_cmd(slidir, nlmdn_output, args)) else: print("{} has been already reconstructed".format(ctset[0])) # execute commands = start reconstruction start = time.time() print("*********** PROCESSING ************") for cmd in cmds: if not args.main_config_dry_run: os.system(cmd) else: print(cmd) if not args.main_config_keep_temp: clean_tmp_dirs(args.main_config_temp_dir, fdt_names) print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx") print("*** Done. Total processing time {} sec.".format(int(time.time() - start))) print("*** Waiting for the next job...........") # cmnds, axes = get_ufo_cmnds(W, tmpdir, recodir, fol, axes = None, dryrun = False) def already_recd(ctset, indir, recd_sets): x = False if ctset[len(indir) + 1 :] in recd_sets: x = True return x def findSlicesDirs(lvl0): recd_sets = [] for root, dirs, files in os.walk(lvl0): for name in dirs: if name == "sli": recd_sets.append(root[len(lvl0) + 1 :]) return recd_sets tofu-0.12.0/tofu/ez/main_nlm.py000066400000000000000000000023371423713721100163140ustar00rootroot00000000000000""" Created on Dec 1, 2020 @author: sergei gasilov """ import os import warnings from tofu.ez.ctdir_walker import WalkCTdirs from tofu.ez.util import * from tofu.util import get_filenames from tofu.util import get_filenames, read_image from tofu.ez.util import enquote warnings.filterwarnings("ignore") def fmt_ufo_cmd(inp, out, args): cmd = "ufo-launch read path={}".format(inp) cmd += " ! non-local-means patch-radius={}".format(args.patch_r) cmd += " search-radius={}".format(args.search_r) cmd += " h={}".format(args.h) cmd += " sigma={}".format(args.sig) cmd += " window={}".format(args.w) cmd += " fast={}".format(args.fast) cmd += " estimate-sigma={}".format(args.autosig) cmd += " ! write filename={}".format(enquote(out)) if not args.bigtif: cmd += " bytes-per-file=0 tiff-bigtiff=False" return cmd def main_tk(args): if args.input_is_file: out_pattern = args.outdir else: if not os.path.exists(args.outdir): os.makedirs(args.outdir) out_pattern = os.path.join(args.outdir, "im-nlmfilt-%05i.tif") cmd = fmt_ufo_cmd(args.indir, out_pattern, args) if args.dryrun: print(cmd) else: os.system(cmd) return 0 tofu-0.12.0/tofu/ez/params.py000066400000000000000000000005621423713721100160030ustar00rootroot00000000000000# This file is used to share params as a global variable import yaml #TODO Make good structure to store parameters # similar to tofu? and # use tofu's structure for existing reco params params = {} def save_parameters(params, file_path): file_out = open(file_path, 'w') yaml.dump(params, file_out) print("Parameters file saved at: " + str(file_path)) tofu-0.12.0/tofu/ez/tofu_cmd_gen.py000066400000000000000000000537411423713721100171600ustar00rootroot00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import os import numpy as np from tofu.ez.ufo_cmd_gen import fmt_in_out_path class tofu_cmds(object): """ Generates partially formatted ufo-launch and tofu commands Parameters are included in the string; pathnames must be added """ def __init__(self, fol): self._fdt_names = fol def make_inpaths(self, lvl0, flats2, args): """ Creates a list of paths to flats/darks/tomo directories :param lvl0: Root of directory containing flats/darks/tomo :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2 :return: List of paths to the directories containing darks/flats/tomo and flats2 (if used) """ indir = [] # If using flats/darks/flats2 in same dir as tomo if not args.main_config_common_flats_darks: for i in self._fdt_names[:3]: indir.append(os.path.join(lvl0, i)) if flats2 - 3: indir.append(os.path.join(lvl0, self._fdt_names[3])) return indir # If using common flats/darks/flats2 across multiple reconstructions elif args.main_config_common_flats_darks: indir.append(args.main_config_darks_path) indir.append(args.main_config_flats_path) indir.append(os.path.join(lvl0, self._fdt_names[2])) if args.main_config_flats2_checkbox: indir.append(args.main_config_flats2_path) return indir def check_lamino(self, cmd, args): cmd += 'tofu reco' if not args.advanced_advtofu_lamino_angle == '': cmd += ' --axis-angle-x {}'.format(args.advanced_advtofu_lamino_angle) if not args.advanced_adv_tofu_z_axis_rotation == '': cmd += ' --overall-angle {}'.format(args.advanced_adv_tofu_z_axis_rotation) if not args.advanced_advtofu_center_position_z == '': cmd += ' --center-position-z {}'.format(args.advanced_advtofu_center_position_z) if not args.advanced_advtofu_y_axis_rotation == '': cmd += ' --axis-angle-y {}'.format(args.advanced_advtofu_y_axis_rotation) return cmd def check_8bit(self, cmd, gray256, bit, hmin, hmax): if gray256: cmd += " --output-bitdepth {}".format(bit) # cmd += " --output-minimum \" {}\" --output-maximum \" {}\""\ # .format(hmin, hmax) cmd += ' --output-minimum " {}" --output-maximum " {}"'.format(hmin, hmax) return cmd def check_vcrop(self, cmd, vcrop, y, yheight, ystep, ori_height): if vcrop: cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep) else: cmd += " --height {}".format(ori_height) return cmd def check_bigtif(self, cmd, swi): if not swi: cmd += " --output-bytes-per-file 0" return cmd def get_1step_ct_cmd(self, ctset, out_pattern, ax, args, nviews, WH): # direct CT reconstruction from input dir to output dir; # or CT reconstruction after preprocessing only indir = self.make_inpaths(ctset[0], ctset[1], args) # correct location of proj folder in case if prepro was done in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2], False) indir[2] = os.path.join(os.path.split(indir[2])[0], os.path.split(in_proj_dir)[1]) # format command cmd = "tofu tomo --absorptivity --fix-nan-and-inf" cmd += " --darks {} --flats {} --projections {}".format(indir[0], indir[1], indir[2]) if ctset[1] == 4: # must be equivalent to len(indir)>3 cmd += " --flats2 {}".format(indir[3]) cmd += " --output {}".format(out_pattern) cmd += " --axis {}".format(ax) cmd += " --offset {}".format(args.main_region_rotate_volume_clock) cmd += " --number {}".format(nviews) if args.step > 0.0: cmd += ' --angle {}'.format(args.step) cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) cmd = self.check_8bit(cmd, args.main_region_clip_histogram, args.main_region_bit_depth, args.main_region_histogram_min, args.main_region_histogram_max) cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff) return cmd def get_ct_proj_cmd(self, out_pattern, ax, args, nviews, WH): # CT reconstruction from pre-processed and flat-corrected projections in_proj_dir, quatsch = fmt_in_out_path( args.main_config_temp_dir, "obsolete;if-you-need-fix-it", self._fdt_names[2], False ) cmd = "tofu tomo --projections {}".format(in_proj_dir) cmd += " --output {}".format(out_pattern) cmd += " --axis {}".format(ax) cmd += " --offset {}".format(args.main_region_rotate_volume_clock) cmd += " --number {}".format(nviews) if args.step > 0.0: cmd += ' --angle {}'.format(args.step) cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) cmd = self.check_8bit(cmd, args.main_region_clip_histogram, args.main_region_bit_depth, args.main_region_histogram_min, args.main_region_histogram_max) cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff) return cmd def get_ct_sin_cmd(self, out_pattern, ax, args, nviews, WH): sinos_dir = os.path.join(args.main_config_temp_dir, 'sinos-filt') cmd = 'tofu tomo --sinograms {}'.format(sinos_dir) cmd += ' --output {}'.format(out_pattern) cmd += ' --axis {}'.format(ax) cmd += ' --offset {}'.format(args.main_region_rotate_volume_clock) if args.main_region_select_rows: cmd += ' --number {}'.format(int(args.main_region_number_rows / args.main_region_nth_row)) else: cmd += " --number {}".format(WH[0]) cmd += " --height {}".format(nviews) if args.step > 0.0: cmd += ' --angle {}'.format(args.step) cmd = self.check_8bit(cmd, args.main_region_clip_histogram, args.main_region_bit_depth, args.main_region_histogram_min, args.main_region_histogram_max) cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff) return cmd def get_sinos_ffc_cmd(self, ctset, tmpdir, args, nviews, WH): indir = self.make_inpaths(ctset[0], ctset[1], args) in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2], False) cmd = 'tofu sinos --absorptivity --fix-nan-and-inf' cmd += ' --darks {} --flats {} '.format(indir[0], indir[1]) if ctset[1] == 4: cmd += " --flats2 {}".format(indir[3]) cmd += " --projections {}".format(in_proj_dir) cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif")) cmd += " --number {}".format(nviews) cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) if not args.main_filters_ring_removal_ufo_lpf: # because second RR algorithm does not know how to work with multipage tiffs cmd += " --output-bytes-per-file 0" if not args.advanced_advtofu_aux_ffc_dark_scale == "": cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale) if not args.advanced_advtofu_aux_ffc_flat_scale == "": cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale) return cmd def get_sinos_noffc_cmd(self, ctsetpath, tmpdir, args, nviews, WH): in_proj_dir, out_pattern = fmt_in_out_path( args.main_config_temp_dir, ctsetpath, self._fdt_names[2], False ) cmd = "tofu sinos" cmd += " --projections {}".format(in_proj_dir) cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif")) cmd += " --number {}".format(nviews) cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) if not args.main_filters_ring_removal_ufo_lpf: # because second RR algorithm does not know how to work with multipage tiffs cmd += " --output-bytes-per-file 0" return cmd def get_sinos2proj_cmd(self, args, proj_height): quatsch, out_pattern = fmt_in_out_path(args.main_config_temp_dir, 'quatsch', self._fdt_names[2], True) cmd = 'tofu sinos' cmd += ' --projections {}'.format(os.path.join(args.main_config_temp_dir, 'sinos-filt')) cmd += ' --output {}'.format(out_pattern) if not args.main_region_select_rows: cmd += ' --number {}'.format(proj_height) else: cmd += ' --number {}'.format(int(args.main_region_number_rows / args.main_region_nth_row)) return cmd def get_sinFFC_cmd(self, ctset, args, nviews, n): indir = self.make_inpaths(ctset[0], ctset[1], args) in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2]) cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --method {}'.format(args.advanced_ffc_method) cmd += ' --multiprocessing' cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps) cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample) cmd += ' --downsample {}'.format(args.advanced_ffc_downsample) return cmd def get_pr_sinFFC_cmd(self, ctset, args, nviews, n): indir = self.make_inpaths(ctset[0], ctset[1], args) in_proj_dir, out_pattern = fmt_in_out_path( args.main_config_temp_dir, ctset[0], self._fdt_names[2]) cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --method {}'.format(args.advanced_ffc_method) cmd += ' --multiprocessing' cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps) cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample) cmd += ' --downsample {}'.format(args.advanced_ffc_downsample) return cmd def get_pr_tofu_cmd_sinFFC(self, ctset, args, nviews, WH): # indir will format paths to flats darks and tomo2 correctly even if they were # pre-processed, however path to the input directory with projections # cannot be formatted with that command correctly # indir = self.make_inpaths(ctset[0], ctset[1]) # so we need a separate "universal" command which considers all previous steps in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2]) # Phase retrieval cmd = 'tofu preprocess --delta 1e-6' cmd += ' --energy {} --propagation-distance {}' \ ' --pixel-size {} --regularization-rate {:0.2f}' \ .format(args.main_pr_photon_energy, args.main_pr_detector_distance, args.main_pr_pixel_size, args.main_pr_delta_beta_ratio) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) cmd += ' --projection-crop-after filter' return cmd def get_pr_tofu_cmd(self, ctset, args, nviews, WH): # indir will format paths to flats darks and tomo2 correctly even if they were # pre-processed, however path to the input directory with projections # cannot be formatted with that command correctly indir = self.make_inpaths(ctset[0], ctset[1], args) # so we need a separate "universal" command which considers all previous steps in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2]) cmd = 'tofu preprocess --fix-nan-and-inf --projection-filter none --delta 1e-6' cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) cmd += ' --output {}'.format(out_pattern) cmd += ' --energy {} --propagation-distance {}' \ ' --pixel-size {} --regularization-rate {:0.2f}' \ .format(args.main_pr_photon_energy, args.main_pr_detector_distance, args.main_pr_pixel_size, args.main_pr_delta_beta_ratio) if not args.advanced_advtofu_aux_ffc_dark_scale == "": cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale) if not args.advanced_advtofu_aux_ffc_flat_scale == "": cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale) return cmd def get_reco_cmd(self, ctset, out_pattern, ax, args, nviews, WH, ffc, PR): # direct CT reconstruction from input dir to output dir; # or CT reconstruction after preprocessing only indir = self.make_inpaths(ctset[0], ctset[1], args) # correct location of proj folder in case if prepro was done in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2], False) # Laminography cmd = '' if args.advanced_advtofu_extended_settings is True: cmd += self.check_lamino(cmd, args) elif args.advanced_advtofu_extended_settings is False: cmd = "tofu reco" cmd += ' --overall-angle 180' ############## cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) if ffc: cmd += ' --fix-nan-and-inf' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) if ctset[1] == 4: # must be equivalent to len(indir)>3 cmd += ' --flats2 {}'.format(indir[3]) if not PR: cmd += ' --absorptivity' if not args.advanced_advtofu_aux_ffc_dark_scale == "": cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale) if not args.advanced_advtofu_aux_ffc_flat_scale == "": cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale) if PR: cmd += ( " --disable-projection-crop" " --delta 1e-6" " --energy {} --propagation-distance {}" " --pixel-size {} --regularization-rate {:0.2f}" \ .format(args.main_pr_photon_energy, args.main_pr_detector_distance, args.main_pr_pixel_size, args.main_pr_delta_beta_ratio) ) cmd += " --center-position-x {}".format(ax) # if args.nviews==0: cmd += " --number {}".format(nviews) # elif args.nviews>0: # cmd += ' --number {}'.format(args.nviews) cmd += ' --volume-angle-z {:0.5f}'.format(args.main_region_rotate_volume_clock) # rows-slices to be reconstructed # full ROI b = int(np.ceil(WH[0] / 2.0)) a = -int(WH[0] / 2.0) c = 1 if args.main_region_select_rows: if args.main_filters_ring_removal: h2 = args.main_region_number_rows / args.main_region_nth_row / 2.0 b = np.ceil(h2) a = -int(h2) else: h2 = int(WH[0] / 2.0) a = args.main_region_first_row - h2 b = args.main_region_first_row + args.main_region_number_rows - h2 c = args.main_region_nth_row cmd += ' --region={},{},{}'.format(a, b, c) # crop of reconstructed slice in the axial plane b = WH[1] / 2 if args.main_region_crop_slices: cmd += ' --x-region={},{},{}'.format(args.main_region_crop_x - b, args.main_region_crop_x + args.main_region_crop_width - b, 1) cmd += ' --y-region={},{},{}'.format(args.main_region_crop_y - b, args.main_region_crop_y + args.main_region_crop_height - b, 1) # cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) cmd = self.check_8bit(cmd, args.main_region_clip_histogram, args.main_region_bit_depth, args.main_region_histogram_min, args.main_region_histogram_max) cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff) # Optimization cmd += ' --slice-memory-coeff={}'.format(args.advanced_optimize_slice_mem_coeff) if args.advanced_optimize_verbose_console: cmd += ' --verbose' if not args.advanced_optimize_num_gpus == '': cmd += ' --gpus {}'.format(args.advanced_optimize_num_gpus) if not args.advanced_optimize_slices_per_device == '': cmd += ' --slices-per-device {}'.format(args.advanced_optimize_slices_per_device) return cmd def get_reco_cmd_sinFFC(self, ctset, out_pattern, ax, args, nviews, WH, ffc, PR): # direct CT reconstruction from input dir to output dir; # or CT reconstruction after preprocessing only indir = self.make_inpaths(ctset[0], ctset[1], args) # correct location of proj folder in case if prepro was done in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2], False) # in_proj_dir, quatsch = fmt_in_out_path(args.tmpdir,args.indir, self._fdt_names[2], False) # indir[2]=os.path.join(os.path.split(indir[2])[0], os.path.split(in_proj_dir)[1]) # format command cmd = "tofu reco" # Laminography if args.advanced_advtofu_extended_settings: cmd += self.check_lamino(cmd, args) else: cmd += " --overall-angle 180" ############## cmd += " --projections {}".format(in_proj_dir) cmd += " --output {}".format(out_pattern) if PR: cmd += ' --disable-projection-crop' \ ' --delta 1e-6' \ ' --energy {} --propagation-distance {}' \ ' --pixel-size {} --regularization-rate {:0.2f}' \ .format(args.main_pr_photon_energy, args.main_pr_detector_distance, args.main_pr_pixel_size, args.main_pr_delta_beta_ratio) cmd += ' --center-position-x {}'.format(ax) # if args.nviews==0: cmd += " --number {}".format(nviews) # elif args.nviews>0: # cmd += ' --number {}'.format(args.nviews) cmd += " --volume-angle-z {:0.5f}".format(args.main_region_rotate_volume_clock) # rows-slices to be reconstructed # full ROI b = int(np.ceil(WH[0] / 2.0)) a = -int(WH[0] / 2.0) c = 1 if args.main_region_select_rows: if args.main_filters_ring_removal: h2 = args.main_region_number_rows / args.main_region_nth_row / 2.0 b = np.ceil(h2) a = -int(h2) else: h2 = int(WH[0] / 2.0) a = args.main_region_first_row - h2 b = args.main_region_first_row + args.main_region_number_rows - h2 c = args.main_region_nth_row cmd += ' --region={},{},{}'.format(a, b, c) # crop of reconstructed slice in the axial plane b = WH[1] / 2 if args.main_region_crop_slices: cmd += ' --x-region={},{},{}'.format(args.main_region_crop_x - b, args.main_region_crop_x + args.main_region_crop_width - b, 1) cmd += ' --y-region={},{},{}'.format(args.main_region_crop_y - b, args.main_region_crop_y + args.main_region_crop_height - b, 1) # cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0]) cmd = self.check_8bit(cmd, args.main_region_clip_histogram, args.main_region_bit_depth, args.main_region_histogram_min, args.main_region_histogram_max) cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff) # Optimization cmd += ' --slice-memory-coeff={}'.format(args.advanced_optimize_slice_mem_coeff) if args.advanced_optimize_verbose_console: cmd += ' --verbose' if not args.advanced_optimize_num_gpus == '': cmd += ' --gpus {}'.format(args.advanced_optimize_num_gpus) if not args.advanced_optimize_slices_per_device == '': cmd += ' --slices-per-device {}'.format(args.advanced_optimize_slices_per_device) return cmd tofu-0.12.0/tofu/ez/ufo_cmd_gen.py000066400000000000000000000255161423713721100167730ustar00rootroot00000000000000#!/bin/python """ Created on Apr 6, 2018 @author: gasilos """ import glob import os from tofu.util import get_filenames, read_image, next_power_of_two from tofu.ez.util import enquote def fmt_in_out_path(tmpdir, indir, raw_proj_dir_name, croutdir=True): # suggests input and output path to directory with proj # depending on number of processing steps applied so far li = sorted(glob.glob(os.path.join(tmpdir, "proj-step*"))) proj_dirs = [d for d in li if os.path.isdir(d)] Nsteps = len(proj_dirs) in_proj_dir, out_proj_dir = "qqq", "qqq" if Nsteps == 0: # no projections in temporary directory in_proj_dir = os.path.join(indir, raw_proj_dir_name) out_proj_dir = "proj-step1" elif Nsteps > 0: # there are directories proj-stepX in tmp dir in_proj_dir = proj_dirs[-1] out_proj_dir = "{}{}".format(in_proj_dir[:-1], Nsteps + 1) else: raise ValueError("Something is wrong with in/out filenames") # physically create output directory tmp = os.path.join(tmpdir, out_proj_dir) if croutdir and not os.path.exists(tmp): os.makedirs(tmp) # return names of input directory and output pattern with abs path return in_proj_dir, os.path.join(tmp, "proj-%04i.tif") class ufo_cmds(object): """ Generates partially formatted ufo-launch and tofu commands Parameters are included in the string; pathnames must be added """ def __init__(self, fol): self._fdt_names = fol def make_inpaths(self, lvl0, flats2, args): """ Creates a list of paths to flats/darks/tomo directories :param lvl0: Root of directory containing flats/darks/tomo :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2 :return: List of paths to the directories containing darks/flats/tomo and flats2 (if used) """ indir = [] # If using flats/darks/flats2 in same dir as tomo if not args.main_config_common_flats_darks: for i in self._fdt_names[:3]: indir.append(os.path.join(lvl0, i)) if flats2 - 3: indir.append(os.path.join(lvl0, self._fdt_names[3])) return indir # If using common flats/darks/flats2 across multiple reconstructions elif args.main_config_common_flats_darks: indir.append(args.main_config_darks_path) indir.append(args.main_config_flats_path) indir.append(os.path.join(lvl0, self._fdt_names[2])) if args.main_config_flats2_checkbox: indir.append(args.main_config_flats2_path) return indir def check_vcrop(self, cmd, vcrop, y, yheight, ystep): if vcrop: cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep) return cmd def check_bigtif(self, cmd, swi): if not swi: cmd += " bytes-per-file=0" return cmd def get_pr_ufo_cmd(self, args, nviews, WH): # in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,args.main_config_input_dir,self._fdt_names[2]) in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, "quatsch", self._fdt_names[2]) cmds = [] pad_width = next_power_of_two(WH[1] + 50) pad_height = next_power_of_two(WH[0] + 50) pad_x = (pad_width - WH[1]) / 2 pad_y = (pad_height - WH[0]) / 2 cmd = 'ufo-launch read path={} height={} number={}'.format(in_proj_dir, WH[0], nviews) cmd += ' ! pad x={} width={} y={} height={}'.format(pad_x, pad_width, pad_y, pad_height) cmd += ' addressing-mode=clamp_to_edge' cmd += ' ! fft dimensions=2 ! retrieve-phase' cmd += ' energy={} distance={} pixel-size={} regularization-rate={:0.2f}' \ .format(args.main_pr_photon_energy, args.main_pr_detector_distance, args.main_pr_pixel_size, args.main_pr_delta_beta_ratio) cmd += ' ! ifft dimensions=2 crop-width={} crop-height={}' \ .format(pad_width, pad_height) cmd += ' ! crop x={} width={} y={} height={}'.format(pad_x, WH[1], pad_y, WH[0]) cmd += ' ! opencl kernel=\'absorptivity\' ! opencl kernel=\'fix_nan_and_inf\' !' cmd += ' write filename={}'.format(enquote(out_pattern)) cmds.append(cmd) if not args.main_config_keep_temp: cmds.append('rm -rf {}'.format(in_proj_dir)) return cmds def get_filter1d_sinos_cmd(self, tmpdir, RR, nviews): sin_in = os.path.join(tmpdir, 'sinos') out_pattern = os.path.join(tmpdir, 'sinos-filt/sin-%04i.tif') pad_height = next_power_of_two(nviews + 500) pad_y = (pad_height - nviews) / 2 cmd = 'ufo-launch read path={}'.format(sin_in) cmd += ' ! pad y={} height={}'.format(pad_y, pad_height) cmd += ' addressing-mode=clamp_to_edge' cmd += ' ! transpose ! fft dimensions=1 ! filter-stripes1d strength={}'.format(RR) cmd += ' ! ifft dimensions=1 ! transpose' cmd += ' ! crop y={} height={}'.format(pad_y, nviews) cmd += ' ! write filename={}'.format(enquote(out_pattern)) return cmd def get_filter2d_sinos_cmd(self, tmpdir, sig_hor, sig_ver, nviews, w): sin_in = os.path.join(tmpdir, "sinos") out_pattern = os.path.join(tmpdir, "sinos-filt/sin-%04i.tif") pad_height = next_power_of_two(nviews + 500) pad_y = (pad_height - nviews) / 2 pad_width = next_power_of_two(w + 500) pad_x = (pad_width - w) / 2 cmd = "ufo-launch read path={}".format(sin_in) cmd += " ! pad x={} width={} y={} height={}".format(pad_x, pad_width, pad_y, pad_height) cmd += " addressing-mode=mirrored_repeat" cmd += " ! fft dimensions=2 ! filter-stripes horizontal-sigma={} vertical-sigma={}".format( sig_hor, sig_ver ) cmd += " ! ifft dimensions=2 crop-width={} crop-height={}".format(pad_width, pad_height) cmd += " ! crop x={} width={} y={} height={}".format(pad_x, w, pad_y, nviews) cmd += " ! write filename={}".format(enquote(out_pattern)) return cmd def get_pre_cmd(self, ctset, pre_cmd, tmpdir, dryrun, args): indir = self.make_inpaths(ctset[0], ctset[1], args) outdir = self.make_inpaths(tmpdir, ctset[1], args) # add index to the name of the output directory with projections # if enabled preprocessing is always the first step outdir[2] = os.path.join(tmpdir, "proj-step1") # we also must create this directory to format paths correctly if not os.path.exists(outdir[2]): os.makedirs(outdir[2]) cmds = [] for i, fol in enumerate(indir): in_pattern = os.path.join(fol, "*.tif") out_pattern = os.path.join(outdir[i], "frame-%04i.tif") cmds.append("ufo-launch") cmds[i] += " read path={} ! ".format(enquote(in_pattern)) cmds[i] += pre_cmd cmds[i] += " ! write filename={}".format(enquote(out_pattern)) return cmds def get_inp_cmd(self, ctset, tmpdir, args, N, nviews, any_flat): indir = self.make_inpaths(ctset[0], ctset[1], args) outdir = self.make_inpaths(tmpdir, ctset[1], args) cmds = [] ######### CREATE MASK ######### mask_file = os.path.join(tmpdir, "mask.tif") # generate mask cmd = 'tofu find-large-spots --images {}'.format(any_flat) cmd += ' --spot-threshold {} --gauss-sigma {}'.format( args.main_filters_remove_spots_threshold, args.main_filters_remove_spots_blur_sigma) cmd += ' --output {} --output-bytes-per-file 0'.format(mask_file) cmds.append(cmd) ######### FLAT-CORRECT ######### in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2]) if args.advanced_ffc_sinFFC: cmd = 'bmit_sin --fix-nan' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(os.path.dirname(out_pattern)) cmd += ' --multiprocessing' #cmd += ' --output {}'.format(out_pattern) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) # Add options for eigen-pco-repetitions etc. cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps) cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample) cmd += ' --downsample {}'.format(args.advanced_ffc_downsample) #if not args.main_pr_phase_retrieval: # cmd += ' --absorptivity' cmds.append(cmd) elif not args.advanced_ffc_sinFFC: cmd = 'tofu flatcorrect --fix-nan-and-inf' cmd += ' --darks {} --flats {}'.format(indir[0], indir[1]) cmd += ' --projections {}'.format(in_proj_dir) cmd += ' --output {}'.format(out_pattern) if ctset[1] == 4: cmd += ' --flats2 {}'.format(indir[3]) if not args.main_pr_phase_retrieval: cmd += ' --absorptivity' if not args.advanced_advtofu_aux_ffc_dark_scale == "": cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale) if not args.advanced_advtofu_aux_ffc_flat_scale == "": cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale) cmds.append(cmd) if not args.main_config_keep_temp and args.main_config_preprocess: cmds.append('rm -rf {}'.format(indir[0])) cmds.append('rm -rf {}'.format(indir[1])) cmds.append('rm -rf {}'.format(in_proj_dir)) if len(indir) > 3: cmds.append("rm -rf {}".format(indir[3])) ######### INPAINT ######### in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2]) cmd = "ufo-launch [read path={} height={} number={}".format(in_proj_dir, N, nviews) cmd += ", read path={}]".format(mask_file) cmd += " ! horizontal-interpolate ! " cmd += "write filename={}".format(enquote(out_pattern)) cmds.append(cmd) if not args.main_config_keep_temp: cmds.append("rm -rf {}".format(in_proj_dir)) return cmds def get_crop_sli(self, out_pattern, args): cmd = 'ufo-launch read path={}/*.tif ! '.format(os.path.dirname(out_pattern)) cmd += 'crop x={} width={} y={} height={} ! '. \ format(args.main_region_crop_x, args.main_region_crop_width, args.main_region_crop_y, args.main_region_crop_height) cmd += 'write filename={}'.format(out_pattern) if args.main_region_clip_histogram: cmd += ' bits=8 rescale=False' return cmd tofu-0.12.0/tofu/ez/util.py000066400000000000000000000175301423713721100155000ustar00rootroot00000000000000""" Created on Apr 20, 2020 @author: gasilos """ import os import tifffile import yaml import numpy as np from tofu.util import get_filenames, get_first_filename, get_image_shape, read_image import tofu.ez.params as parameters def get_dims(pth): # get number of projections and projections dimensions first_proj = get_first_filename(pth) multipage = False try: shape = get_image_shape(first_proj) except: raise ValueError("Failed to determine size and number of projections in {}".format(pth)) if len(shape) == 2: # single page input return len(get_filenames(pth)), [shape[-2], shape[-1]], multipage elif len(shape) == 3: # multipage input nviews = 0 for i in get_filenames(pth): nviews += get_image_shape(i)[0] multipage = True return nviews, [shape[-2], shape[-1]], multipage return -6, [-6, -6] def bad_vert_ROI(multipage, path2proj, y, height): if multipage: with tifffile.TiffFile(get_filenames(path2proj)[0]) as tif: proj = tif.pages[0].asarray().astype(np.float) else: proj = read_image(get_filenames(path2proj)[0]).astype(np.float) y_region = slice(y, min(y + height, proj.shape[0]), 1) proj = proj[y_region, :] if proj.shape[0] == 0: return True else: return False def make_copy_of_flat(flatdir, flat_copy_name, dryrun): first_flat_file = get_first_filename(flatdir) try: shape = get_image_shape(first_flat_file) except: raise ValueError("Failed to determine size and number of flats in {}".format(flatdir)) cmd = "" if len(shape) == 2: last_flat_file = get_filenames(flatdir)[-1] cmd = "cp {} {}".format(last_flat_file, flat_copy_name) else: flat = read_image(get_filenames(flatdir)[-1])[-1] if dryrun: cmd = 'echo Will save a copy of flat into "{}"'.format(flat_copy_name) else: tifffile.imsave(flat_copy_name, flat) return cmd def clean_tmp_dirs(tmpdir, fdt_names): tmp_pattern = ["proj", "sino", "mask", "flat", "dark", "radi"] tmp_pattern += fdt_names # clean directories in tmpdir if their names match pattern if os.path.exists(tmpdir): for filename in os.listdir(tmpdir): if filename[:4] in tmp_pattern: os.system("rm -rf {}".format(os.path.join(tmpdir, filename))) def enquote(string, escape=False): addition = '\\"' if escape else '"' return addition + string + addition def save_params(args, ctsetname, ax, nviews, WH): if not args.main_config_dry_run and not os.path.exists(args.main_config_output_dir): os.makedirs(args.main_config_output_dir) tmp = os.path.join(args.main_config_output_dir, ctsetname) if not args.main_config_dry_run and not os.path.exists(tmp): os.makedirs(tmp) if not args.main_config_dry_run and args.main_config_save_params: # Dump the params .yaml file try: yaml_output_filepath = os.path.join(tmp, "parameters.yaml") yaml_output = open(yaml_output_filepath, "w") yaml.dump(parameters.params, yaml_output) except FileNotFoundError: print("Something went wrong when exporting the .yaml parameters file") # Dump the reco.params output file fname = os.path.join(tmp, 'reco.params') f = open(fname, 'w') f.write('*** General ***\n') f.write('Input directory {}\n'.format(args.main_config_input_dir)) if ctsetname == '': ctsetname = '.' f.write('CT set {}\n'.format(ctsetname)) if args.main_cor_axis_search_method == 1 or args.main_cor_axis_search_method == 2: f.write('Center of rotation {} (auto estimate)\n'.format(ax)) else: f.write('Center of rotation {} (user defined)\n'.format(ax)) f.write('Dimensions of projections {} x {} (height x width)\n'.format(WH[0], WH[1])) f.write('Number of projections {}\n'.format(nviews)) f.write('*** Preprocessing ***\n') tmp = 'None' if args.main_config_preprocess: tmp = args.main_config_preprocess_command f.write(' '+tmp+'\n') f.write('*** Image filters ***\n') if args.main_filters_remove_spots: f.write(' Remove large spots enabled\n') f.write(' threshold {}\n'.format(args.main_filters_remove_spots_threshold)) f.write(' sigma {}\n'.format(args.main_filters_remove_spots_blur_sigma)) else: f.write(' Remove large spots disabled\n') if args.main_pr_phase_retrieval: f.write(' Phase retrieval enabled\n') f.write(' energy {} keV\n'.format(args.main_pr_photon_energy)) f.write(' pixel size {:0.1f} um\n'.format(args.main_pr_pixel_size * 1e6)) f.write(' sample-detector distance {} m\n'.format(args.main_pr_detector_distance)) f.write(' delta/beta ratio {:0.0f}\n'.format(10 ** args.main_pr_delta_beta_ratio)) else: f.write(' Phase retrieval disabled\n') f.write('*** Ring removal ***\n') if args.main_filters_ring_removal: if args.main_filters_ring_removal_ufo_lpf: tmp = '2D' if args.main_filters_ring_removal_ufo_lpf_1d_or_2d: tmp = '1D' f.write(' RR with ufo {} stripes filter\n'.format(tmp)) f.write(f' sigma horizontal {args.main_filters_ring_removal_ufo_lpf_sigma_horizontal}') f.write(f' sigma vertical {args.main_filters_ring_removal_ufo_lpf_sigma_vertical}') else: if args.main_filters_ring_removal_sarepy_wide: tmp = ' RR with ufo sarepy remove wide filter, ' tmp += 'window {}, SNR {}\n'.format( args.main_filters_ring_removal_sarepy_window, args.main_filters_ring_removal_sarepy_SNR) f.write(tmp) f.write(' ' 'RR with ufo sarepy sorting filter, window {}\n'. format(args.main_filters_ring_removal_sarepy_window_size) ) else: f.write('RR disabled\n') f.write('*** Region of interest ***\n') if args.main_region_select_rows: f.write('Vertical ROI defined\n') f.write(' first row {}\n'.format(args.main_region_first_row)) f.write(' height {}\n'.format(args.main_region_number_rows)) f.write(' reconstruct every {}th row\n'.format(args.main_region_nth_row)) else: f.write('Vertical ROI: all rows\n') if args.main_region_crop_slices: f.write('ROI in slice plane defined\n') f.write(' x {}\n'.format(args.main_region_crop_x)) f.write(' width {}\n'.format(args.main_region_crop_width)) f.write(' y {}\n'.format(args.main_region_crop_y)) f.write(' height {}\n'.format(args.main_region_crop_height)) else: f.write('ROI in slice plane not defined\n') f.write('*** Reconstructed values ***\n') if args.main_region_clip_histogram: f.write(' {} bit\n'.format(args.main_region_bit_depth)) f.write(' Min value in 32-bit histogram {}\n'.format(args.main_region_histogram_min)) f.write(' Max value in 32-bit histogram {}\n'.format(args.main_region_histogram_max)) else: f.write(' 32bit, histogram untouched\n') f.write('*** Optional reco parameters ***\n') if args.main_region_rotate_volume_clock > 0: f.write(' Rotate volume by: {:0.3f} deg\n'.format(args.main_region_rotate_volume_clock)) f.close() tofu-0.12.0/tofu/ez/yaml_in_out.py000066400000000000000000000007721423713721100170420ustar00rootroot00000000000000import yaml import logging LOG = logging.getLogger(__name__) class Yaml_IO: def read_yaml(self, filePath): with open(filePath) as f: data = yaml.load(f, Loader=yaml.FullLoader) LOG.debug(data) return data def write_yaml(self, filePath, params): try: file = open(filePath, "w") except FileNotFoundError: LOG.debug("No filename given") else: yaml.dump(params, file) file.close() tofu-0.12.0/tofu/find_large_spots.py000066400000000000000000000043311423713721100174220ustar00rootroot00000000000000import logging from gi.repository import Ufo from tofu.util import set_node_props, determine_shape, setup_read_task, setup_padding from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def find_large_spots(args): graph = Ufo.TaskGraph() sched = Ufo.FixedScheduler() reader = get_task('read') writer = get_writer(args) if args.gauss_sigma and args.blurred_output: broadcast = Ufo.CopyTask() blurred_writer = get_task('write') if hasattr(blurred_writer.props, 'bytes_per_file'): blurred_writer.props.bytes_per_file = 0 if hasattr(blurred_writer.props, 'tiff_bigtiff'): blurred_writer.props.tiff_bigtiff = False blurred_writer.props.filename = args.blurred_output find = get_task('find-large-spots') set_node_props(find, args) find.props.addressing_mode = args.find_large_spots_padding_mode set_node_props(reader, args) setup_read_task(reader, args.images, args) if args.gauss_sigma: reader_2 = get_task('read') set_node_props(reader_2, args) setup_read_task(reader_2, args.images, args) pad = get_task('pad') crop = get_task('crop') opencl = get_task('opencl', kernel='diff', filename='opencl.cl') width, height = determine_shape(args, path=args.images) gauss_size = int(10 * args.gauss_sigma) setup_padding(pad, width, height, args.find_large_spots_padding_mode, crop=crop, pad_width=gauss_size, pad_height=gauss_size) LOG.debug("Gauss size: %d", gauss_size) blur = get_task('blur', sigma=args.gauss_sigma, size=gauss_size) graph.connect_nodes_full(reader, opencl, 0) graph.connect_nodes(reader_2, pad) graph.connect_nodes(pad, blur) graph.connect_nodes(blur, crop) graph.connect_nodes_full(crop, opencl, 1) if args.blurred_output: graph.connect_nodes(opencl, broadcast) graph.connect_nodes(broadcast, blurred_writer) source = broadcast else: source = opencl graph.connect_nodes(source, find) else: graph.connect_nodes(reader, find) graph.connect_nodes(find, writer) sched.run(graph) tofu-0.12.0/tofu/flow/000077500000000000000000000000001423713721100144745ustar00rootroot00000000000000tofu-0.12.0/tofu/flow/__init__.py000066400000000000000000000000001423713721100165730ustar00rootroot00000000000000tofu-0.12.0/tofu/flow/composites/000077500000000000000000000000001423713721100166615ustar00rootroot00000000000000tofu-0.12.0/tofu/flow/composites/ffc-links.cm000066400000000000000000000221021423713721100210530ustar00rootroot00000000000000{ "name": "CFlatFieldCorrect", "caption": "CFlatFieldCorrect", "models": { "Flat Field Correct": { "model": { "caption": "Flat Field Correct", "properties": { "fix-nan-and-inf": [ true, true ], "absorption-correct": [ true, true ], "sinogram-input": [ false, false ], "dark-scale": [ 1.0, false ], "flat-scale": [ 1.0, false ] } }, "visible": true, "position": { "x": 1253.0, "y": 490.0 }, "name": "flat_field_correct" }, "Read 2": { "model": { "caption": "Read 2", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 4294967295, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 417.0, "y": 504.0 }, "name": "read" }, "Average": { "model": { "caption": "Average", "properties": { "number": [ 4294967295, true ] } }, "visible": true, "position": { "x": 822.0, "y": 508.0 }, "name": "average" }, "Read 3": { "model": { "caption": "Read 3", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 4294967295, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 413.0, "y": 735.0 }, "name": "read" }, "Average 2": { "model": { "caption": "Average 2", "properties": { "number": [ 4294967295, true ] } }, "visible": true, "position": { "x": 822.0, "y": 741.0 }, "name": "average" }, "Read": { "model": { "caption": "Read", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 23212, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 418.0, "y": 245.0 }, "name": "read" } }, "connections": [ [ "Read", 0, "Flat Field Correct", 0 ], [ "Average", 0, "Flat Field Correct", 1 ], [ "Average 2", 0, "Flat Field Correct", 2 ], [ "Read 2", 0, "Average", 0 ], [ "Read 3", 0, "Average 2", 0 ] ], "links": [ [ [ "Read 2", "number" ], [ "Average", "number" ] ], [ [ "Read 3", "number" ], [ "Average 2", "number" ] ] ] } tofu-0.12.0/tofu/flow/composites/pr.cm000066400000000000000000000115161423713721100176270ustar00rootroot00000000000000{ "name": "CPhaseRetrieve", "caption": "CPhaseRetrieve", "models": { "Fft": { "model": { "caption": "Fft", "properties": { "auto-zeropadding": [ true, true ], "dimensions": [ 2, true ], "size-x": [ 1, true ], "size-y": [ 1, true ], "size-z": [ 1, true ] } }, "visible": true, "position": { "x": 112.0, "y": 245.0 }, "name": "fft" }, "Ifft": { "model": { "caption": "Ifft", "properties": { "dimensions": [ 2, true ], "crop-width": [ -1, true ], "crop-height": [ -1, true ] } }, "visible": true, "position": { "x": 772.0, "y": 250.0 }, "name": "ifft" }, "Retrieve Phase": { "model": { "caption": "Retrieve Phase", "num-inputs": 1, "properties": { "method": [ "tie", true ], "energy": [ 20.0, true ], "distance": [ 0.0, true ], "distance-x": [ 0.0, true ], "distance-y": [ 0.0, true ], "pixel-size": [ 7.500000265281415e-07, true ], "regularization-rate": [ 2.5, true ], "thresholding-rate": [ 0.10000000149011612, true ], "frequency-cutoff": [ 3.4028234663852886e+38, true ], "output-filter": [ false, true ] } }, "visible": true, "position": { "x": 544.0, "y": 515.0 }, "name": "retrieve_phase" }, "Pad": { "model": { "caption": "Pad", "properties": { "width": [ 0, true ], "height": [ 0, true ], "x": [ 0, true ], "y": [ 0, true ], "addressing-mode": [ "clamp", true ] } }, "visible": true, "position": { "x": 0.0, "y": 570.0 }, "name": "pad" } }, "connections": [ [ "Pad", 0, "Fft", 0 ], [ "Fft", 0, "Retrieve Phase", 0 ], [ "Retrieve Phase", 0, "Ifft", 0 ] ], "links": [ [ [ "Fft", "dimensions" ], [ "Ifft", "dimensions" ] ], [ [ "Fft", "size-x" ], [ "Pad", "width" ] ], [ [ "Fft", "size-y" ], [ "Pad", "height" ] ] ] } tofu-0.12.0/tofu/flow/config.json000066400000000000000000000067711423713721100166470ustar00rootroot00000000000000{ "models": { "average": { "hidden-properties": [ "number" ] }, "flat-field-correct": { "port-captions": { "input": { "0": "radios", "1": "darks", "2": "flats" }, "output": { "0": "" } }, "hidden-properties": [ "sinogram-input", "dark-scale", "flat-scale" ] }, "general-backproject": { "hidden-properties": [ "z", "burst", "source-position-x", "source-position-y", "source-position-z", "detector-position-x", "detector-position-y", "detector-position-z", "detector-angle-x", "detector-angle-y", "detector-angle-z", "axis-angle-x", "axis-angle-y", "axis-angle-z", "volume-angle-x", "volume-angle-y", "volume-angle-z", "compute-type", "result-type", "store-type", "addressing-mode", "gray-map-min", "gray-map-max" ], "range-properties": { "region": [3, true], "x-region": [3, true], "y-region": [3, true], "center-position-x": [null, true], "center-position-z": [null, true], "source-position-x": [null, true], "source-position-y": [null, true], "source-position-z": [null, true], "detector-position-x": [null, true], "detector-position-y": [null, true], "detector-position-z": [null, true], "detector-angle-x": [null, true], "detector-angle-y": [null, true], "detector-angle-z": [null, true], "axis-angle-x": [null, true], "axis-angle-y": [null, true], "axis-angle-z": [null, true], "volume-angle-x": [null, true], "volume-angle-y": [null, true], "volume-angle-z": [null, true] } }, "horizontal-interpolate": { "port-captions": { "input": { "0": "image", "1": "mask" }, "output": { "0": "" } } }, "read": { "hidden-properties": [ "start", "step", "y", "height", "y-step", "convert", "raw-width", "raw-height", "raw-bitdepth", "raw-pre-offset", "raw-post-offset", "type", "retries", "retry-timeout" ] }, "write": { "hidden-properties": [ "counter-start", "counter-step", "bytes-per-file", "append", "bits", "minimum", "maximum", "rescale", "jpeg-quality", "tiff-bigtiff" ] } } } tofu-0.12.0/tofu/flow/execution.py000066400000000000000000000202561423713721100170560ustar00rootroot00000000000000import gi import logging import networkx as nx gi.require_version('Ufo', '0.0') from gi.repository import Ufo from PyQt5.QtCore import QObject, pyqtSignal from qtpynodeeditor import PortType from threading import Thread from tofu.flow.models import ARRAY_DATA_TYPE, UFO_DATA_TYPE, UfoTaskModel from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class UfoExecutor(QObject): """Class holding GPU resources and organizing UFO graph execution.""" number_of_inputs_changed = pyqtSignal(int) # Number of inputs has been determined processed_signal = pyqtSignal(int) # Image has been processed execution_started = pyqtSignal() # Graph execution started execution_finished = pyqtSignal() # Graph execution finished exception_occured = pyqtSignal(str) def __init__(self): super().__init__(parent=None) self._resources = Ufo.Resources() self._reset() # If True only log the exception and emit the signal but don't re-raise it in the executing # thread self.swallow_run_exceptions = False def _reset(self): self._aborted = False self._schedulers = [] self.num_generated = 0 def abort(self): LOG.debug('Execution aborted') try: self._aborted = True for scheduler in self._schedulers: scheduler.abort() finally: self.execution_finished.emit() def on_processed(self, ufo_task): self.processed_signal.emit(self.num_generated) self.num_generated += 1 def setup_ufo_graph(self, graph, gpu=None, region=None, signalling_model=None): ufo_graph = Ufo.TaskGraph() ufo_tasks = {} for source, dest, ports in graph.edges.data(): if hasattr(source, 'create_ufo_task') and hasattr(dest, 'create_ufo_task'): if dest not in ufo_tasks: ufo_tasks[dest] = dest.create_ufo_task(region=region) if source not in ufo_tasks: ufo_tasks[source] = source.create_ufo_task(region=region) ufo_graph.connect_nodes_full(ufo_tasks[source], ufo_tasks[dest], ports[PortType.input]) LOG.debug(f'{source.name}->{dest.name}@{ports[PortType.input]}') if source == signalling_model: ufo_tasks[source].connect('generated', self.on_processed) if gpu is not None: for task in ufo_tasks.values(): if task.uses_gpu(): task.set_proc_node(gpu) return ufo_graph def _run_ufo_graph(self, ufo_graph, use_fixed_scheduler): LOG.debug(f'Executing graph, fixed scheduler: {use_fixed_scheduler}') try: scheduler = Ufo.FixedScheduler() if use_fixed_scheduler else Ufo.Scheduler() self._schedulers.append(scheduler) scheduler.set_resources(self._resources) scheduler.run(ufo_graph) LOG.info(f'Execution time: {scheduler.props.time} s') except Exception as e: # Do not continue execution of other batches self._aborted = True LOG.error(e, exc_info=True) self.exception_occured.emit(str(e)) if not self.swallow_run_exceptions: raise e def check_graph(self, graph): """ Check that *graph* starts with an UfoTaskModel and ends with either that or an UfoModel but no UfoTaskModel successor exists (there can be only one UFO path in the graph). """ roots = [n for n in graph.nodes if graph.in_degree(n) == 0] leaves = [n for n in graph.nodes if graph.out_degree(n) == 0] for root in roots: for leave in leaves: for path in nx.simple_paths.all_simple_paths(graph, root, leave): if not isinstance(path[0], UfoTaskModel): raise FlowError('Flow must start with an UFO node') ufo_ended = False for (i, succ) in enumerate(path[1:]): model = path[i] edge_data = graph.get_edge_data(model, succ) if len(edge_data) > 1: # There cannot be multiple edges between nodes raise FlowError('Multiple edges not allowed but detected ' 'between {model} and {succ}') out_index = edge_data[0]['output'] # We don't need to check if input data type is ARRAY_DATA_TYPE because # UFO_DATA_TYPE cannot be connected to ARRAY_DATA_TYPE in the scene if ufo_ended: # From now on only non-UFO tasks are allowed if model.data_type['output'][out_index] != ARRAY_DATA_TYPE: raise FlowError('After a non-UFO node cannot come another UFO node') elif model.data_type['output'][out_index] != UFO_DATA_TYPE: # Output is non-UFO, UFO ends here ufo_ended = True def run(self, graph): self._reset() self.check_graph(graph) gpus = self._resources.get_gpu_nodes() num_inputs = -1 signalling_model = None for model in graph.nodes: if graph.in_degree(model) == 0: if 'number' in model: current = model['number'] if current > num_inputs: num_inputs = current signalling_model = model batches = [[(None, None)]] gpu_splitting_model = None gpu_splitting_models = get_gpu_splitting_models(graph) if len(gpu_splitting_models) > 1: # There cannot be multiple splitting models raise FlowError('Only one gpu splitting model is allowed') elif gpu_splitting_models: gpu_splitting_model = gpu_splitting_models[0] batches = gpu_splitting_model.split_gpu_work(self._resources.get_gpu_nodes()) for model in graph.nodes: # Reset internal model state if hasattr(model, 'reset_batches'): model.reset_batches() LOG.debug(f'{len(batches)} batches: {batches}') if signalling_model: self.number_of_inputs_changed.emit(len(batches) * num_inputs) LOG.debug(f'Number of inputs: {len(batches) * num_inputs}, defined ' f'by {signalling_model}') def execute_batches(): self.execution_started.emit() try: for (i, parallel_batch) in enumerate(batches): LOG.info(f'starting batch {i}: {parallel_batch}') threads = [] for gpu_index, region in parallel_batch: if self._aborted: break gpu = None if gpu_index is None else gpus[gpu_index] ufo_graph = self.setup_ufo_graph(graph, gpu=gpu, region=region, signalling_model=signalling_model) t = Thread(target=self._run_ufo_graph, args=(ufo_graph, len(gpu_splitting_models) > 0)) t.daemon = True threads.append(t) t.start() for t in threads: t.join() if self._aborted: break except Exception as e: LOG.error(e, exc_info=True) self.exception_occured.emit(str(e)) raise e finally: self.execution_finished.emit() gt = Thread(target=execute_batches) gt.daemon = True gt.start() def get_gpu_splitting_models(graph): gpu_splitting_models = [] for model in graph.nodes: if isinstance(model, UfoTaskModel) and model.can_split_gpu_work: gpu_splitting_models.append(model) return gpu_splitting_models tofu-0.12.0/tofu/flow/filedirdialog.py000066400000000000000000000013131423713721100176420ustar00rootroot00000000000000import os from PyQt5.QtWidgets import QFileDialog class FileDirDialog(QFileDialog): """ A workaround for being able to select both files and directories. Source: https://stackoverflow.com/questions/27520304/qfiledialog-that-accepts-a-single-file-or-a-single-directory """ def __init__(self, parent=None): super().__init__(parent=parent) self.setOption(QFileDialog.DontUseNativeDialog) self.setFileMode(QFileDialog.Directory) self.currentChanged.connect(self._selected) def _selected(self, name): if os.path.isdir(name): self.setFileMode(QFileDialog.Directory) else: self.setFileMode(QFileDialog.ExistingFile) tofu-0.12.0/tofu/flow/main.py000066400000000000000000000543411423713721100160010ustar00rootroot00000000000000import json import logging import os import pathlib import sys from PyQt5.QtCore import Qt, QObject, QPoint, pyqtSignal from PyQt5.QtWidgets import (QApplication, QFileDialog, QWidget, QVBoxLayout, QMenuBar, QMessageBox, QProgressBar, QMainWindow, QStyle) from qtpynodeeditor import DataModelRegistry, FlowView from xdg import xdg_data_home from tofu.flow.execution import UfoExecutor from tofu.flow.models import (BaseCompositeModel, get_composite_model_classes_from_json, get_composite_model_classes, get_ufo_model_classes, ImageViewerModel, UfoGeneralBackprojectModel, UfoMemoryOutModel, UfoOpenCLModel, UfoReadModel, UfoRetrievePhaseModel, UfoWriteModel) from tofu.flow.scene import UfoScene from tofu.flow.propertylinkswidget import PropertyLinks from tofu.flow.runslider import RunSlider from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class ApplicationWindow(QMainWindow): def __init__(self, ufo_scene): super().__init__() self.ufo_scene = ufo_scene self.property_links_widget = PropertyLinks(ufo_scene.node_model, ufo_scene.property_links_model, parent=self) self.run_slider = RunSlider(parent=self) self.executor = UfoExecutor() self.console = None self.run_slider_key = (None, None) self.last_dirs = {'scene': None, 'composite': None} self._creating_composite = False self._expanding_composite = False central_widget = QWidget() self.setCentralWidget(central_widget) main_layout = QVBoxLayout(central_widget) self.flow_view = FlowView(self.ufo_scene) self.progress_bar = QProgressBar() self.progress_bar.setMinimum(0) menu_bar = QMenuBar() flow_menu = menu_bar.addMenu('Flow') new_action = flow_menu.addAction("New") new_action.setShortcut('Ctrl+N') new_action.triggered.connect(self.on_new) save_action = flow_menu.addAction("Save") save_action.setShortcut('Ctrl+S') save_action.triggered.connect(self.on_save) load_action = flow_menu.addAction("Open") load_action.setShortcut('Ctrl+O') load_action.triggered.connect(self.on_open) self.run_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaPlay), 'Run') self.run_action.setShortcut('Ctrl+R') self.run_action.triggered.connect(self.on_run) abort_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaStop), 'Abort') abort_action.setShortcut('Ctrl+Shift+X') abort_action.triggered.connect(self.executor.abort) exit_action = flow_menu.addAction('Exit') exit_action.setShortcut('Ctrl+Q') exit_action.triggered.connect(self.close) # Nodes submenu selection_menu = menu_bar.addMenu('Nodes') selection_menu.setToolTipsVisible(True) selection_menu.aboutToShow.connect(self.on_selection_menu_about_to_show) self.skip_action = selection_menu.addAction('Skip Toggle') self.skip_action.setShortcut('S') self.skip_action.triggered.connect(self.ufo_scene.skip_nodes) auto_fill_action = selection_menu.addAction('Auto fill') auto_fill_action.triggered.connect(self.ufo_scene.auto_fill) copy_action = selection_menu.addAction("Duplicate") copy_action.setShortcut('Ctrl+Shift+D') copy_action.triggered.connect(self.ufo_scene.copy_nodes) # Composite create_composite_action = selection_menu.addAction("Create Composite") create_composite_action.setShortcut('Ctrl+Shift+C') create_composite_action.triggered.connect(self.on_create_composite) import_composites_action = selection_menu.addAction("Import Composites") import_composites_action.setToolTip('Import one or more composite nodes ' 'from a file or files') import_composites_action.setShortcut('Ctrl+I') import_composites_action.triggered.connect(self.on_import_composites) self.export_composite_action = selection_menu.addAction("Export Composite") self.export_composite_action.triggered.connect(self.on_export_composite) self.edit_composite_action = selection_menu.addAction("Edit Composite") self.edit_composite_action.triggered.connect(self.on_edit_composite) self.expand_composite_action = selection_menu.addAction("Expand Composite") self.expand_composite_action.setShortcut('Ctrl+Shift+E') self.expand_composite_action.triggered.connect(self.on_expand_composite) view_menu = menu_bar.addMenu('View') reset_view_action = view_menu.addAction("Reset Zoom") reset_view_action.setShortcut('Ctrl+0') reset_view_action.triggered.connect(self.on_reset_view) property_links_action = view_menu.addAction("Link Properties") property_links_action.setShortcut('Ctrl+L') property_links_action.triggered.connect(self.on_property_links_action) console_action = view_menu.addAction("Open Python Console") console_action.setShortcut('Ctrl+Shift+P') console_action.triggered.connect(self.on_console_action) run_slider_action = view_menu.addAction("Run Slider") run_slider_action.setShortcut('Ctrl+Shift+S') run_slider_action.triggered.connect(self.on_run_slider_action) self.fix_run_slider = view_menu.addAction("Fix Run Slider") self.fix_run_slider.setCheckable(True) self.fix_run_slider.setShortcut('Ctrl+Alt+Shift+S') main_layout.addWidget(menu_bar) main_layout.addWidget(self.flow_view) main_layout.addWidget(self.progress_bar) main_layout.setContentsMargins(0, 0, 0, 0) main_layout.setSpacing(0) self.resize(1280, 1000) # Signals self.executor.exception_occured.connect(self.on_exception_occured) self.executor.execution_finished.connect(self.on_execution_finished) self.executor.number_of_inputs_changed.connect(self.on_number_of_inputs_changed) self.executor.processed_signal.connect(self.on_processed) self.ufo_scene.node_deleted.connect(self.on_node_deleted) self.ufo_scene.nodes_duplicated.connect(self.on_nodes_duplicated) self.ufo_scene.item_focus_in.connect(self.on_item_focus_in) self.run_slider.value_changed.connect(self.on_run_slider_value_changed) self.setWindowTitle('tofu flow') def on_save(self): if self.last_dirs['scene']: path = self.last_dirs['scene'] else: path = os.path.join(xdg_data_home(), 'tofu', 'flows') if not os.path.exists(path): os.makedirs(path) file_name, _ = QFileDialog.getSaveFileName(self, "Select File Name", str(path), "Flow Scene Files (*.flow)") if file_name: self.last_dirs['scene'] = os.path.dirname(file_name) self.ufo_scene.save(file_name) def on_new(self): self.run_slider.reset() self.ufo_scene.clear_scene() self.setWindowTitle('tofu flow') def on_open(self): if self.last_dirs['scene']: path = self.last_dirs['scene'] else: path = os.path.join(xdg_data_home(), 'tofu', 'flows') if not os.path.exists(path): path = pathlib.Path.home() file_name, _ = QFileDialog.getOpenFileName(self, "Open Flow Scene", str(path), "Flow Scene Files (*.flow)") if file_name: self.last_dirs['scene'] = os.path.dirname(file_name) self.ufo_scene.load(file_name) self.run_slider.reset() self.setWindowTitle(file_name) def on_exception_occured(self, text): msg = QMessageBox(parent=self) msg.setIcon(QMessageBox.Critical) msg.setText(text) msg.setWindowTitle("Error") msg.exec_() def on_number_of_inputs_changed(self, value): self.progress_bar.setMaximum(value) def on_processed(self, value): self.progress_bar.setValue(value + 1) def on_node_deleted(self, node): slider_model, prop_name = self.run_slider_key if slider_model: if (isinstance(node.model, BaseCompositeModel) and node.model.is_model_inside(slider_model) and not (self._expanding_composite or self._creating_composite)): self.run_slider.reset() self.run_slider_key = (None, None) elif node.model == slider_model and not self._creating_composite: self.run_slider.reset() self.run_slider_key = (None, None) def on_nodes_duplicated(self, selected_nodes, new_nodes): min_y = float('inf') y_1 = float('-inf') for node in selected_nodes: height = node.model.embedded_widget().height() y = node.graphics_object.y() if y < min_y: min_y = y if y + height > y_1: y_1 = y + height for node in selected_nodes: dy = node.graphics_object.y() - min_y new_pos = QPoint(int(node.graphics_object.x()), int(dy + y_1 + 100)) new_nodes[node].graphics_object.setPos(new_pos) def on_item_focus_in(self, item, prop_name, caption, model): if not self.fix_run_slider.isChecked() or not self.run_slider.view_item: if self.run_slider.setup(item): self.run_slider_key = (model, prop_name) self.run_slider.setWindowTitle(f'{caption}->{prop_name}') def on_selection_menu_about_to_show(self): composites = False num_selected = len(self.ufo_scene.selected_nodes()) for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): composites = True break self.edit_composite_action.setEnabled(num_selected == 1 and composites) self.export_composite_action.setEnabled(num_selected == 1 and composites) self.expand_composite_action.setEnabled(composites) self.skip_action.setEnabled(self.ufo_scene.selected_nodes() != []) def on_edit_composite(self): if self.ufo_scene.is_selected_one_composite(): # Check again in case this was invoked by the keyboard shortcut node = self.ufo_scene.selected_nodes()[0] node.model.edit_in_window(self) def on_create_composite(self): self._creating_composite = True try: path = None prop_name = self.run_slider_key[1] if self.run_slider_key[0]: for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): if node.model.is_model_inside(self.run_slider_key[0]): path = node.model.get_path_from_model(self.run_slider_key[0]) elif node.model == self.run_slider_key[0]: path = [self.run_slider_key[0]] composite_model = self.ufo_scene.create_composite().model if path: str_path = [model.caption for model in path] new_model = composite_model.get_model_from_path(str_path) new_view_item = new_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just update the view item self.run_slider.view_item = new_view_item self.run_slider_key = (new_model, prop_name) title = '->'.join([composite_model.caption] + str_path + [prop_name]) self.run_slider.setWindowTitle(title) finally: self._creating_composite = False def on_expand_composite(self): self._expanding_composite = True try: slider_model, prop_name = self.run_slider_key for node in self.ufo_scene.selected_nodes(): if isinstance(node.model, BaseCompositeModel): if slider_model: str_path = None if node.model.is_model_inside(slider_model): str_path = [model.caption for model in node.model.get_path_from_model(slider_model)] new_nodes = self.ufo_scene.expand_composite(node)[0] # Pass the new node to the run slider if it was contained in this composite if slider_model and str_path: if slider_model.caption in new_nodes: # runslider linked to a simple node after expanstion slider_model = new_nodes[slider_model.caption].model self.run_slider_key = (slider_model, prop_name) new_view_item = slider_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just update the # view item self.run_slider.view_item = new_view_item self.run_slider.setWindowTitle(f'{slider_model.caption}->{prop_name}') else: # runslider linked to another composite node (nesting) after expanstion for node in new_nodes.values(): if isinstance(node.model, BaseCompositeModel): if node.model.contains_path(str_path[2:]): new_model = node.model.get_model_from_path(str_path[2:]) self.run_slider_key = (new_model, prop_name) new_view_item = new_model.get_view_item(prop_name) # Do not make complete setup, that would reset limits, just # update the view item self.run_slider.view_item = new_view_item title = '->'.join(str_path[1:] + [prop_name]) self.run_slider.setWindowTitle(title) self.run_slider_key = (new_model, prop_name) break finally: self._expanding_composite = False def on_import_composites(self): if self.last_dirs['composite']: path = self.last_dirs['composite'] else: path = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites') if not os.path.exists(path): path = pathlib.Path.home() file_names, _ = QFileDialog.getOpenFileNames(self, "Select File Names", str(path), "Composite Model Files (*.cm)") if not file_names: return self.last_dirs['composite'] = os.path.dirname(file_names[0]) overwriting = {} for file_name in file_names: LOG.debug(f'Loading composite from {file_name}') with open(file_name, 'r') as f: state = json.load(f) for model in get_composite_model_classes_from_json(state): if model.name in self.ufo_scene.registry.registered_model_creators(): overwriting[model.name] = os.path.basename(file_name) self.ufo_scene.registry.register_model(model, category='Composite', registry=self.ufo_scene.registry) if overwriting: msg = QMessageBox(parent=self) msg.setIcon(QMessageBox.Warning) msg.setText('Composite nodes with same names detected. Files from which ' 'the nodes have been loaded are listed in details.') msg.setDetailedText('\n'.join([f'Node name "{name}" from file "{file_name}"' for (name, file_name) in overwriting.items()])) msg.setWindowTitle('Warning') msg.exec_() def export_composite(self, node, file_name): state = node.model.save() with open(file_name, 'w') as f: json.dump(state, f, indent=4) def on_export_composite(self): if not self.ufo_scene.is_selected_one_composite(): # Check again in case this was invoked by the keyboard shortcut return if self.last_dirs['composite']: path = self.last_dirs['composite'] else: path = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites') if not os.path.exists(path): os.makedirs(path) file_name, _ = QFileDialog.getSaveFileName(self, "Select File Name", str(path), "Composite Model Files (*.cm)") if file_name: self.last_dirs['composite'] = os.path.dirname(file_name) if not file_name.endswith('.cm'): file_name += '.cm' self.export_composite(self.ufo_scene.selected_nodes()[0], file_name) def on_reset_view(self): for view in self.ufo_scene.views(): transform = view.transform() transform.reset() view.setTransform(transform) def on_property_links_action(self): self.property_links_widget.show() # Make sure it goes to the front if it is currently burried under other windows self.property_links_widget.raise_() def on_console_action(self): if self.console: self.console.show() return try: from pyqtconsole.console import PythonConsole from pyqtconsole.highlighter import format self.console = PythonConsole(formats={ 'keyword': format('darkBlue', 'bold') }) self.console.setWindowFlag(Qt.SubWindow, True) self.console.ctrl_d_exits_console(True) self.console.push_local_ns('scene', self.ufo_scene) self.console.resize(640, 480) self.console.show() self.console.eval_queued() except ImportError as e: LOG.error(e, exc_info=True) self.on_exception_occured(str(e)) def on_run_slider_action(self): if not self.run_slider.view_item: msg = QMessageBox(parent=self) msg.setIcon(QMessageBox.Information) msg.setText('Click on an input field in the flow to connect the slider') msg.exec_() else: self.run_slider.show() # Make sure it goes to the front if it is currently burried under other windows self.run_slider.raise_() def on_run_slider_value_changed(self, value): if self.run_action.isEnabled(): self.on_run() def on_run(self): graphs = self.ufo_scene.get_simple_node_graphs() if len(graphs) != 1: raise FlowError('Scene must contain one fully connected graph') if not self.ufo_scene.is_fully_connected(): raise FlowError('Not all node ports are connected') self.executor.run(graphs[0]) self.run_action.setEnabled(False) self.ufo_scene.set_enabled(False) def on_execution_finished(self): self.progress_bar.reset() self.run_action.setEnabled(True) self.ufo_scene.set_enabled(True) class GlobalExceptionHandler(QObject): """ Intercept exceptions, log them and inform user if they are UI-related. Emit a signal when the error message should be shown to the user so that e.g. a message can be shown in the main thread. """ exception_occured = pyqtSignal(str) def excepthook(self, exc_type, exc_value, exc_traceback): LOG.error(exc_value, exc_info=(exc_type, exc_value, exc_traceback)) if issubclass(exc_type, FlowError): self.exception_occured.emit(str(exc_value)) def get_filled_registry(): registry = DataModelRegistry() for model in get_ufo_model_classes(): category = 'Processing' if model.num_ports['input'] == 0: category = 'Input' if model.num_ports['output'] == 0: category = 'Output' registry.register_model(model, category=category, scrollable=True) registry.register_model(UfoGeneralBackprojectModel, category='Processing') registry.register_model(UfoOpenCLModel, category='Processing') registry.register_model(UfoRetrievePhaseModel, category='Processing') registry.register_model(UfoMemoryOutModel, category='Data') registry.register_model(ImageViewerModel, category='Output') registry.register_model(UfoWriteModel, category='Output') registry.register_model(UfoReadModel, category='Input') for models in get_composite_model_classes(): for model in models: if model.name not in registry.registered_model_creators(): registry.register_model(model, category='Composite', registry=registry) return registry def main(): app = QApplication(sys.argv) scene = UfoScene(registry=get_filled_registry()) main_window = ApplicationWindow(scene) # Exception interception exception_handler = GlobalExceptionHandler() exception_handler.exception_occured.connect(main_window.on_exception_occured) # Do not use threading.excepthook because it needs at least python 3.8., i.e. all exceptions in # threads have to be handled properly (logged, signal emitted so that a message can be displayed # in the main thread to the user, see tofu.flow.execution for example). sys.excepthook = exception_handler.excepthook main_window.show() sys.exit(app.exec_()) if __name__ == '__main__': main() tofu-0.12.0/tofu/flow/models.py000066400000000000000000001726251423713721100163460ustar00rootroot00000000000000""" All classes needed for :class:`qtpynodeeditor.NodeDataModel` implementation of UFO and composite tasks. """ import gi import glob import json import logging import networkx as nx import numpy as np import pkg_resources import os import re gi.require_version('Ufo', '0.0') from gi.repository import Ufo from PyQt5 import QtCore from PyQt5.QtCore import QObject, Qt, pyqtSignal from PyQt5.QtGui import QDoubleValidator, QValidator from PyQt5.QtWidgets import (QCheckBox, QComboBox, QGroupBox, QInputDialog, QLabel, QLineEdit, QScrollArea, QWidget, QFileDialog, QFormLayout, QVBoxLayout, QMenu) from qtpynodeeditor import (NodeData, NodeDataModel, NodeDataType, FlowScene, FlowView, Port, PortType, opposite_port) from threading import Lock from tofu.flow.util import (CompositeConnection, FlowError, get_config_key, MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE, saved_kwargs) from tofu.flow.filedirdialog import FileDirDialog LOG = logging.getLogger(__name__) UFO_PLUGIN_MANAGER = Ufo.PluginManager() UFO_DATA_TYPE = NodeDataType(id="UfoBuffer", name=None) ARRAY_DATA_TYPE = NodeDataType(id="NumpyArray", name=None) class UfoIntValidator(QValidator): """Combined int and unsigned int validator.""" def __init__(self, minimum, maximum, parent=None): super().__init__(parent=parent) self.minimum = minimum self.maximum = maximum def bottom(self): return self.minimum def top(self): return self.maximum def validate(self, input_str, pos): try: if self.minimum <= int(input_str) <= self.maximum: result = (QValidator.Acceptable, input_str, pos) else: result = (QValidator.Intermediate, input_str, pos) except ValueError: if not input_str or input_str == '-' and self.minimum < 0: result = (QValidator.Intermediate, input_str, pos) else: result = (QValidator.Invalid, input_str, pos) return result class UfoRangeValidator(QValidator): """ Range separated by comma validator. *num_items* specifies how many numbers must be in the string. *is_float* specifies if the numbers are floating point (integer or unsigned integer otherwise). """ def __init__(self, num_items=None, is_float=True, parent=None): super().__init__(parent=parent) self.num_items = num_items self.is_float = is_float def validate(self, input_str, pos): float_regexp = r'[+-]|[+-]?(\d+(\.\d*)?|\.\d*)([eE][+-]?\d*)?' numbers = input_str.split(',') intermediate = False if self.num_items is not None and len(numbers) > self.num_items: # Incorrect number of items return (QValidator.Invalid, input_str, pos) for (i, number) in enumerate(numbers): number = number.lower().strip() if ('e' in number or '.' in number) and not self.is_float: # Integer expected return (QValidator.Invalid, input_str, pos) if self.is_float: try: float(number) except: if (not number or re.fullmatch(float_regexp, number)): # Partial floating point number (e.g. ends with "e") intermediate = True continue else: return (QValidator.Invalid, input_str, pos) else: try: int(number) except: if not number or number == '-': intermediate = True continue else: return (QValidator.Invalid, input_str, pos) if intermediate or (self.num_items is not None and len(numbers) < self.num_items): # Not enough arguments received or some numbers are incomplete return (QValidator.Intermediate, input_str, pos) return (QValidator.Acceptable, input_str, pos) class ViewItem(QObject): property_changed = pyqtSignal(QObject) def __init__(self, widget, default_value=None, tooltip=''): super().__init__(parent=None) self.widget = widget self.focus_info = False if tooltip: self.widget.setToolTip(tooltip) if default_value is not None: self.set(default_value) def on_changed(self, *args): """ Only user interaction must emit signals in the descendants. Signal is emitted only if the user input is valid. """ try: self.get() self.property_changed.emit(self) except: LOG.debug(f'{self}: invalid input') def get(self): ... def set(self, value): ... class CheckBoxViewItem(ViewItem): def __init__(self, checked=False, tooltip=''): widget = QCheckBox() super().__init__(widget, default_value=checked, tooltip=tooltip) widget.clicked.connect(self.on_changed) def get(self): return self.widget.isChecked() def set(self, value): self.widget.setChecked(value) class ComboBoxViewItem(ViewItem): def __init__(self, items, default_value=None, tooltip=''): widget = QComboBox() for item in items: widget.addItem(item) super().__init__(widget, default_value=default_value, tooltip=tooltip) widget.activated.connect(self.on_changed) def get(self): return self.widget.currentText() def set(self, value): self.widget.setCurrentText(value) class FocusInterceptQLineEdit(QLineEdit): focus_in = pyqtSignal(QObject) def focusInEvent(self, event): self.focus_in.emit(self) return super().focusInEvent(event) class QLineEditViewItem(ViewItem): focus_in = pyqtSignal(QObject) def __init__(self, default_value=None, tooltip='', intercept_focus=False): if intercept_focus: widget = FocusInterceptQLineEdit() widget.focus_in.connect(self.on_focus_in) else: widget = QLineEdit() super().__init__(widget, default_value=default_value, tooltip=tooltip) if intercept_focus: self.focus_info = True widget.textEdited.connect(self.on_changed) def on_focus_in(self, widget): self.focus_in.emit(self) def get(self): return self.widget.text() def set(self, value): self.widget.setText(str(value)) class NumberQLineEditViewItem(QLineEditViewItem): def __init__(self, minimum, maximum, default_value=None, tooltip=''): if default_value < minimum or default_value > maximum: raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]') tooltip += ' (range: {} - {})'.format(minimum, maximum) super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = QDoubleValidator(float(minimum), float(maximum), 100) self.widget.setValidator(validator) def get(self): return float(super().get()) class IntQLineEditViewItem(QLineEditViewItem): def __init__(self, minimum, maximum, default_value=None, tooltip=''): if default_value < minimum or default_value > maximum: raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]') tooltip += ' (range: {} - {})'.format(minimum, maximum) super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = UfoIntValidator(minimum, maximum) self.widget.setValidator(validator) def get(self): return int(super().get()) class RangeQLineEditViewItem(QLineEditViewItem): def __init__(self, default_value='', tooltip='', num_items=None, is_float=True): super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True) validator = UfoRangeValidator(num_items=num_items, is_float=is_float) self.widget.setValidator(validator) def set(self, values): text = ','.join([str(value) for value in values]) if values else '' self.widget.setText(text) def get(self): text = super().get() if text: values = [float(num) for num in text.split(',')] else: values = [] return values def get_ufo_qline_edit_item(glib_prop, default_value, range_num_items=None, range_is_float=True): if glib_prop.value_type.name == 'GValueArray': item = RangeQLineEditViewItem(tooltip=glib_prop.blurb, default_value=default_value, num_items=range_num_items, is_float=range_is_float) elif glib_prop.value_type.name in ['gdouble', 'gfloat']: item = NumberQLineEditViewItem(glib_prop.minimum, glib_prop.maximum, default_value=default_value, tooltip=glib_prop.blurb) elif hasattr(glib_prop, 'minimum') and hasattr(glib_prop, 'maximum'): item = IntQLineEditViewItem(glib_prop.minimum, glib_prop.maximum, default_value=default_value, tooltip=glib_prop.blurb) else: item = QLineEditViewItem(default_value=str(default_value), tooltip=glib_prop.blurb) return item class PropertyViewRecord: """Attribute-access to a view's item.""" def __init__(self, view_item, label, visible): self.view_item = view_item self.label = label self.visible = visible def __str__(self): return repr(self) def __repr__(self): fmt = 'PropertyViewRecord(widget={}, visible={})' return fmt.format(self.view_item.widget, self.visible) class MultiPropertyViewRecord: """Attribute-access to a multiple property view's item.""" def __init__(self, model, widget, visible): self.model = model self.widget = widget self.visible = visible def __str__(self): return repr(self) def __repr__(self): fmt = 'MultiPropertyViewRecord(model={}, widget={}, visible={})' return fmt.format(self.model, self.widget, self.visible) class PropertyView(QWidget): property_changed = pyqtSignal(str, object) item_focus_in = pyqtSignal(ViewItem, str) def __init__(self, properties=None, parent=None, scrollable=True): super().__init__(parent=parent) form_layout = QFormLayout() form_layout.setVerticalSpacing(0) self._properties = {} if properties: for (name, (item, active)) in properties.items(): if name in self._properties: raise ValueError("Item '{}' already exists".format(name)) # Set the parent properly, so that set_property_visible won't try to show the item # widget and the label in their own windows before the view is shown item.widget.setParent(self) label = QLabel(name, parent=self) form_layout.addRow(label, item.widget) self._properties[name] = PropertyViewRecord(item, label, active) self.set_property_visible(name, active) item.property_changed.connect(self.on_property_changed) if item.focus_info: item.focus_in.connect(self.on_item_focus_in) if scrollable: widget = QWidget() widget.setLayout(form_layout) scroll = QScrollArea() scroll.setWidget(widget) scroll.setWidgetResizable(True) main_layout = QVBoxLayout() main_layout.addWidget(scroll) self.setLayout(main_layout) else: self.setLayout(form_layout) @property def property_names(self): return self._properties.keys() def get_property(self, name): return self._properties[name].view_item.get() def set_property(self, name, value): return self._properties[name].view_item.set(value) def get_record(self, name): return self._properties[name] def on_property_changed(self, item): # Get item's name for (name, record) in self._properties.items(): if item == record.view_item: break self.property_changed.emit(name, item.get()) def on_item_focus_in(self, view_item): for (name, it) in self._properties.items(): if it.view_item.widget == view_item.widget: self.item_focus_in.emit(view_item, name) break def is_property_visible(self, name): return self._properties[name].visible def set_property_visible(self, name, visible): self._properties[name].view_item.widget.setVisible(visible) self._properties[name].label.setVisible(visible) self._properties[name].visible = visible def restore_properties(self, values): for prop in self._properties: if prop not in values: LOG.debug(f'Property {prop} not stored, using default') continue value, visible = values[prop] self.set_property(prop, value) self.set_property_visible(prop, visible) def export_properties(self): values = {} for prop in self._properties: values[prop] = [self.get_property(prop), self.is_property_visible(prop)] return values def contextMenuEvent(self, event): contextMenu = QMenu(self) actions = {} for name in list(self._properties.keys()): action = contextMenu.addAction(name) action.setCheckable(True) action.setChecked(self._properties[name].visible) actions[action] = name contextMenu.addSeparator() show_all_action = contextMenu.addAction('Show All') hide_all_action = contextMenu.addAction('Hide All') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if action: if action in actions: name = actions[action] checked = action.isChecked() self.set_property_visible(name, checked) elif action == show_all_action: for name in self._properties.keys(): self.set_property_visible(name, True) elif action == hide_all_action: for name in self._properties.keys(): self.set_property_visible(name, False) class MultiPropertyView(QWidget): def __init__(self, groups, parent=None): super().__init__(parent=parent) self._group_box_layout = QVBoxLayout() main_layout = QVBoxLayout() widget = QWidget() widget.setLayout(self._group_box_layout) scroll = QScrollArea() scroll.setWidget(widget) scroll.setWidgetResizable(True) self.setLayout(main_layout) main_layout.addWidget(scroll) self._groups = {} for (model, visible) in groups.items(): if isinstance(model, PropertyModel): model_widget = QGroupBox(model.caption) layout = QVBoxLayout() model_widget.setLayout(layout) layout.addWidget(model.embedded_widget()) else: model_widget = QLabel(model.caption, parent=self) record = MultiPropertyViewRecord(model, model_widget, visible) self._groups[model.caption] = record self._group_box_layout.addWidget(model_widget) self.set_group_visible(model.caption, visible) def __getitem__(self, key): return self._groups[key].model def __contains__(self, key): return key in self._groups def __iter__(self): return iter(self._groups) def export_groups(self): values = {} for name in self._groups: state = self._groups[name].model.save() values[name] = {'model': state, 'visible': self._groups[name].visible} return values def restore_groups(self, values): for name in values: self[name].restore(values[name]['model']) self.set_group_visible(name, values[name]['visible']) def set_group_visible(self, name, visible): self._groups[name].widget.setVisible(visible) self._groups[name].visible = visible def is_group_visible(self, name): return self._groups[name].visible def contextMenuEvent(self, event): contextMenu = QMenu(self) actions = {} for name in list(self._groups.keys()): action = contextMenu.addAction(name) action.setCheckable(True) action.setChecked(self._groups[name].visible) actions[action] = name contextMenu.addSeparator() show_all_action = contextMenu.addAction('Show All') hide_all_action = contextMenu.addAction('Hide All') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if action: if action in actions: name = actions[action] checked = action.isChecked() self.set_group_visible(name, checked) elif action == show_all_action: for name in self._groups.keys(): self.set_group_visible(name, True) elif action == hide_all_action: for name in self._groups.keys(): self.set_group_visible(name, False) class UfoModel(NodeDataModel): """The root parent of all other models in tofu flow.""" data_type = UFO_DATA_TYPE item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel) def __init__(self, style=None, parent=None): super().__init__(style=style, parent=parent) # This is the caption model wants to have when it's instantiated, however, it might # get a different caption from the scene because the captions must be unique within self.base_caption = self.caption self.skip = False def restore(self, state, restore_caption=False): if restore_caption: self.caption = state.get('caption', self.caption) def save(self): return {'caption': self.caption} def double_clicked(self, parent): ... def __repr__(self): return f'UfoModel({self.caption})' def __str__(self): return repr(self) class PropertyModel(UfoModel): property_changed = pyqtSignal(UfoModel, str, object) def __init__(self, style=None, parent=None, scrollable=True): """*properties* is a dictionary of name: ViewItem items.""" super().__init__(style=style, parent=parent) properties = self.make_properties() if properties: self.properties = list(properties.keys()) self._view = PropertyView(properties=properties, scrollable=scrollable) self._view.property_changed.connect(self.on_property_changed) self._view.item_focus_in.connect(self.on_item_focus_in) else: self.properties = [] self._view = None def __getitem__(self, key): return self._view.get_property(key) def __setitem__(self, key, value): return self._view.set_property(key, value) def __contains__(self, key): return key in self.properties def __iter__(self): return iter(self.properties) def get_view_item(self, name): return self._view.get_record(name).view_item def on_property_changed(self, name, value): self.property_changed.emit(self, name, value) def on_item_focus_in(self, item, name): self.item_focus_in.emit(item, name, self.caption, self) def make_properties(self): """*properties* is a dictionary of name: ViewItem items.""" return {} def copy_properties(self): properties = self.make_properties() for (name, (item, active)) in properties.items(): item.set(self[name]) properties[name][-1] = self._view.is_property_visible(name) return properties def auto_fill(self): """Automatically fill properties (e.g. number of files, etc.)""" ... def resizable(self): return True def embedded_widget(self) -> QWidget: return self._view if self._view else None def restore(self, state, restore_caption=True): self._view.restore_properties(state['properties']) super().restore(state, restore_caption=restore_caption) def save(self): state = super().save() state['properties'] = self._view.export_properties() return state class UfoTaskModel(PropertyModel): caption_visible = True def __init__(self, task_name, style=None, parent=None, scrollable=True): self._task_name = task_name self.caption = ' '.join([item[0].upper() + item[1:] for item in self.name.split('_')]) self.needs_fixed_scheduler = False self.can_split_gpu_work = False super().__init__(style=style, parent=parent, scrollable=scrollable) def make_properties(self): hidden_properties = get_config_key('models', self._task_name, 'hidden-properties') range_properties = get_config_key('models', self._task_name, 'range-properties', default={}) properties = {} ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name) for prop in ufo_task.list_properties(): if prop.name == 'num-processed': continue default_value = getattr(ufo_task.props, prop.name) if prop.value_type.name == 'gboolean': item = CheckBoxViewItem(checked=default_value, tooltip=prop.blurb) elif hasattr(prop, 'enum_class'): items = [name.value_nick for name in default_value.__enum_values__.values()] item = ComboBoxViewItem(items, default_value=default_value.value_nick, tooltip=prop.blurb) else: range_num_items, range_is_float = range_properties.get(prop.name, (None, True)) item = get_ufo_qline_edit_item(prop, default_value=default_value, range_num_items=range_num_items, range_is_float=range_is_float) visible = True if hidden_properties and prop.name in hidden_properties: visible = False properties[prop.name] = [item, visible] return properties def create_ufo_task(self, region=None): if self.expects_multiple_inputs and region is None: raise UfoModelError(f'{self.caption} expects multiple inputs ' 'but there is no node with such capability in the flow') ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name) self._setup_ufo_task(ufo_task, region=region) return ufo_task def _setup_ufo_task(self, ufo_task, region=None): for prop in self: setattr(ufo_task.props, prop, self[prop]) def reset_batches(self): """ In case the model can process batches and has internal state depending on them, this is where it can be re-set. """ pass @property def uses_gpu(self): return UFO_PLUGIN_MANAGER.get_task(self._task_name).uses_gpu() @property def expects_multiple_inputs(self): return False def get_ufo_model_class(ufo_task_name): # Use this to determine inputs and outputs but create a new object in the constructor in order # to enable multiple instances having different parameter values _ufo_task = UFO_PLUGIN_MANAGER.get_task(ufo_task_name) ufo_task_num_inputs = _ufo_task.get_num_inputs() ufo_task_num_outputs = int(_ufo_task.get_mode() & Ufo.TaskMode.SINK == 0) class UfoAutoModel(UfoTaskModel): name = ufo_task_name.replace('-', '_') def __init__(self, style=None, parent=None, scrollable=True): self.num_ports = {PortType.input: ufo_task_num_inputs, PortType.output: ufo_task_num_outputs} self.data_type = {} self.port_caption = {} self.port_caption_visible = {} for port_type in (PortType.input, PortType.output): self.data_type[port_type] = {} self.port_caption[port_type] = {} self.port_caption_visible[port_type] = {} for i in range(self.num_ports[port_type]): port_captions = get_config_key('models', ufo_task_name, 'port-captions') if port_captions: port_caption = port_captions[port_type][str(i)] port_caption_visible = True if port_caption else False else: port_caption = '' port_caption_visible = False self.data_type[port_type][i] = UFO_DATA_TYPE self.port_caption[port_type][i] = port_caption self.port_caption_visible[port_type][i] = port_caption_visible self.ufo_task = None super().__init__(ufo_task_name, style=style, parent=parent, scrollable=scrollable) return UfoAutoModel class BaseCompositeModel(UfoModel): # Move functionality which can go here from CompositeModel here data_type = UFO_DATA_TYPE def __init__(self, models, connections, links=None, registry=None, style=None, parent=None): if registry is None: # This has to be keyword argument because of the qtpynodeeditor's node creation # mechanism, but the argument is actually required raise AttributeError('registry must be provided') super().__init__(style=style, parent=parent) # Nodes in the edit pop-up window self.window_parent = None self._property_links_model = None self._links = [] if links is None else links self._slave_property_links = [] self._window_nodes = {} self._other_scene = None self._other_view = None self.num_ports = {PortType.input: 0, PortType.output: 0} self.data_type = {PortType.input: {}, PortType.output: {}} self.port_caption = {PortType.input: {}, PortType.output: {}} self.port_caption_visible = {PortType.input: {}, PortType.output: {}} groups = {} self._registry = registry self._models = {} # Internal connections self._connections = connections # Composite port to subnode port mapping self._inside_ports = {} # Subnode port to composite port mapping self._outside_ports = {} for (name, state, visible, position) in models: # Don't use the deafault registry creation because embedded PropertyModel must have # scrollable set to False cls, orig_kwargs = registry.registered_model_creators()[name] # Don't mess with the original dictionary kwargs = {orig_key: orig_value for (orig_key, orig_value) in orig_kwargs.items()} if issubclass(cls, PropertyModel): kwargs['scrollable'] = False if 'num-inputs' in state: kwargs['num_inputs'] = state['num-inputs'] model = cls(**kwargs) model.restore(state) self._models[model] = position groups[model] = visible model.item_focus_in.connect(self.on_item_focus_in) for port_type in ['input', 'output']: for index in range(model.num_ports[port_type]): side = (model.caption, port_type, index) if not any([conn.contains(*side) for conn in connections]): i = self.num_ports[port_type] self.data_type[port_type][i] = UFO_DATA_TYPE port_caption = model.caption if model.port_caption[port_type][index]: port_caption += ':' + model.port_caption[port_type][index] self.port_caption[port_type][i] = port_caption self.port_caption_visible[port_type][i] = True self._inside_ports[(port_type, i)] = (model, port_type, index) self._outside_ports[side] = (port_type, i) self.num_ports[port_type] += 1 self._view = MultiPropertyView(groups) def __getitem__(self, key): return self._view[key] def __contains__(self, key): return key in self._view def __iter__(self): return iter(self._view) def __repr__(self): return f'Composite(caption={self.caption}, models={sorted(list(iter(self._view)))})' def __str__(self): return repr(self) def get_outside_port(self, unique_name, port_type, port_index): return self._outside_ports[(unique_name, port_type, port_index)] def get_model_and_port_index(self, port_type, port_index): model, spt, index = self._inside_ports[(port_type, port_index)] return (model, index) def embedded_widget(self) -> QWidget: return self._view if self._view else None def resizable(self): return True def on_item_focus_in(self, item, name, caption, model): self.item_focus_in.emit(item, name, self.caption + '->' + caption, model) @property def is_editing(self): """Is wubwindow open.""" return self._window_nodes != {} @property def property_links_model(self): return self._property_links_model @property_links_model.setter def property_links_model(self, plm): self._property_links_model = plm for model in self._models: if isinstance(model, BaseCompositeModel): model.property_links_model = plm def contains_path(self, path): """Is there a caption *path* inside this model.""" model = self for caption in path: if caption in model: model = model[caption] else: return False return True def get_model_from_path(self, path): """*path* is caption path (str).""" model = self for caption in path: model = model[caption] return model def is_model_inside(self, model): """Return True if *model* is inside at any level.""" paths = self.get_leaf_paths() for path in paths: for item in path: if item == model: return True return False def get_path_from_model(self, model): """*model* must be inside this composite model.""" paths = self.get_leaf_paths() for path in paths: for (i, item) in enumerate(path): if item == model: return path[:i + 1] raise KeyError(f'{model} not inside') def get_descendant_graph(self, in_subwindow=False): """ Get all descendant models recursively in case there are composite models inside this model. If *in_subwindow* is True, return models shown to the user in the subwindow, otherwise the ones created at class instantiation. For composites inside this one, if *in_subwindow* is True return the subwindow models, but if it's not being edited instead raising an exception, return the internal models. """ if in_subwindow and not self.is_editing: raise ValueError('in_subwindow True but no subwindow open') graph = nx.DiGraph() def descend(parent): if in_subwindow and parent.is_editing: models = [node.model for node in parent._window_nodes.values()] else: models = [parent[key] for key in parent] for model in models: graph.add_edge(parent, model) if isinstance(model, BaseCompositeModel): descend(model) descend(self) return graph def get_leaf_paths(self, in_subwindow=False): graph = self.get_descendant_graph(in_subwindow=in_subwindow) leaves = [node for node in graph.nodes if graph.out_degree(node) == 0] paths = [] for leaf in leaves: paths.append(list(nx.simple_paths.all_simple_paths(graph, self, leaf))[0]) return paths def restore(self, state, restore_caption=True): self._connections = [CompositeConnection(*args) for args in state['connections']] self._view.restore_groups(state['models']) super().restore(state, restore_caption=restore_caption) def restore_links(self, node): if self.property_links_model: row = self.property_links_model.rowCount() for items in self._links: # A row can be restored only if no property from the state is in the link model # yet row_ok = True for str_path in items: prop_name = str_path[-1] model = self.get_model_from_path(str_path[:-1]) if self.property_links_model.find_items([node, model, prop_name], [NODE_ROLE, MODEL_ROLE, PROPERTY_ROLE]): LOG.info(f'{str_path[-2]}->{prop_name} already in property links') row_ok = False break if row_ok: for (i, str_path) in enumerate(items): model = self.get_model_from_path(str_path[:-1]) self.property_links_model.add_item(node, model, str_path[-1], row, i) row += 1 def save(self): state = {'name': self.name, 'caption': self.caption} state['models'] = self._view.export_groups() for (model, position) in self._models.items(): state['models'][model.caption]['position'] = position # This is necessary for creating models from saved files state['models'][model.caption]['name'] = model.name state['connections'] = [conn.save() for conn in self._connections] if self.property_links_model: state['links'] = [] paths = self.get_leaf_paths() models = [path[-1] for path in paths] items = self.property_links_model.get_model_links(models) for row in items.values(): # First item in the row is this model, skip it state['links'].append([str_path[1:] for str_path in row]) return state def on_connection_created(self, connection): self._other_scene.connection_deleted.disconnect(self.on_connection_deleted) self._other_scene.delete_connection(connection) self._other_scene.connection_deleted.connect(self.on_connection_deleted) def on_connection_deleted(self, connection): self._other_scene.connection_created.disconnect(self.on_connection_created) self._other_scene.restore_connection(connection.__getstate__()) self._other_scene.connection_created.connect(self.on_connection_created) def double_clicked(self, parent): self.edit_in_window(parent=parent) def on_other_scene_double_clicked(self, node): node.model.double_clicked(self._other_view) def expand_into_graph(self, graph): """Expand to submodels in a *graph*, which is a networkx.DiGraph instance.""" name_to_model = {} for model in self._models: LOG.debug(f'Adding node {model.name}') graph.add_node(model) name_to_model[model.caption] = model for conn in self._connections: source = name_to_model[conn.from_unique_name] dest = name_to_model[conn.to_unique_name] LOG.debug(f'Adding edge {source.name}@{conn.from_port_index} -> ' f'{dest.name}@{conn.to_port_index}') graph.add_edge(source, dest, input=conn.to_port_index, output=conn.from_port_index) def _expand_into_scene(self, scene, original_nodes=None, restore_captions=False): # unique name to node instance mapping name_to_node = {} for model in self._models: if original_nodes and model.caption in original_nodes: node = scene.restore_node(original_nodes[model.caption]) else: with saved_kwargs(scene.registry, model.__getstate__()): if restore_captions: node = scene.create_node(model.__class__) else: # This is the main scene, links restoration takes place in expand_into_scene # for all nodes including composites node = scene.create_node(model.__class__, restore_links=False) if isinstance(model, PropertyModel) or isinstance(model, BaseCompositeModel): node.model.restore(model.save(), restore_caption=restore_captions) if isinstance(node.model, BaseCompositeModel): node.model.property_links_model = self.property_links_model else: node.model.restore(model.save()) name_to_node[model.caption] = node if self._models[model] is not None: node.position = (self._models[model]['x'], self._models[model]['y']) for conn in self._connections: f_node = name_to_node[conn.from_unique_name] t_node = name_to_node[conn.to_unique_name] f_port = f_node[PortType.output][conn.from_port_index] t_port = t_node[PortType.input][conn.to_port_index] scene.create_connection(f_port, t_port, check_cycles=False) return name_to_node def add_slave_links(self): self._slave_property_links = [] if not self.property_links_model: return for node in self._window_nodes.values(): if isinstance(node.model, BaseCompositeModel): paths = node.model.get_leaf_paths(in_subwindow=node.model._window_nodes != {}) else: paths = [[node.model]] # Propagate all signals from leaves to the original models for path in paths: str_path = [m.caption for m in path] new_model = path[-1] orig_model = self.get_model_from_path(str_path) # Create a link from this node's model instances to the original root # models in the link model (there can be other composites along the way to # the root root_model = self.property_links_model.get_root_model(orig_model) if root_model: prop_names = self.property_links_model.get_model_properties(root_model) for prop_name in prop_names: if (new_model, prop_name) not in self._slave_property_links: # In order to remove slaves when the subwindow is closed, register # the slaves with respect to the most nested composite registering_model = path[-2] if len(path) > 1 else self if registering_model.is_editing: registering_model._slave_property_links.append((new_model, prop_name)) registering_model.property_links_model.add_silent(new_model, prop_name, root_model, prop_name) if registering_model.window_parent: # If the registering model has a parent, register also the # models in it's internal model view new_model = registering_model[path[-1].caption] registering_model = registering_model.window_parent registering_model._slave_property_links.append((new_model, prop_name)) registering_model.property_links_model.add_silent(new_model, prop_name, root_model, prop_name) def edit_in_window(self, parent=None): self._other_scene = FlowScene(registry=self._registry) self._other_scene.node_double_clicked.connect(self.on_other_scene_double_clicked) self._window_nodes = self._expand_into_scene(self._other_scene, restore_captions=True) # Store references to parent composites for node in self._window_nodes.values(): if isinstance(node.model, BaseCompositeModel): node.model.window_parent = self # Property links have to be registered with respect to the top composite model because # it's property model's property model is registered in property links window_parent = self while window_parent.window_parent: window_parent = window_parent.window_parent window_parent.add_slave_links() # Disable manipulation because the number of ports is fixed, so we can't e.g. internally # connect two nodes and delete the newly occupied port from the composite node self._other_scene.allow_node_creation = False self._other_scene.allow_node_deletion = False # There is no allow_connection_creation/deletion, so take care of it here self._other_scene.connection_created.connect(self.on_connection_created) self._other_scene.connection_deleted.connect(self.on_connection_deleted) self._other_view = FlowView(self._other_scene, parent=parent) self._other_view.setWindowFlag(Qt.Window, True) self._other_view.closeEvent = self.view_close_event self._other_view.setWindowTitle(self.name) self._other_view.resize(900, 600) self._other_view.show() def view_close_event(self, event): for node in self._window_nodes.values(): # Clse all composite children recursively first if isinstance(node.model, BaseCompositeModel) and node.model.is_editing: node.model._other_view.close() node.model.window_parent = None for (unique_name, node) in self._window_nodes.items(): self._view[unique_name].restore(node.model.save()) if self.property_links_model: for (model, prop_name) in self._slave_property_links: self.property_links_model.remove_silent(model, prop_name) self._slave_property_links = [] self._window_nodes = {} self._other_scene = None self._other_view = None def expand_into_scene(self, scene, composite_node, original_nodes=None): """ Expand this node into *scene* and replace *composite_node*'s connections with connections going straight into its subnodes. Also create connections internal to this node and update property links. *original_nodes* is a dictionary in form {caption: node_state} which will be used for positioning of the replacing nodes (scene.restore_node instead of scene.create_node will be called). """ assert self.property_links_model is not None # Connections to external nodes connections = [] # name_to_node is in format caption: new node dictionary # Internal connections are handled in _expand_into_scene name_to_node = self._expand_into_scene(scene, original_nodes=original_nodes, restore_captions=False) for port_type in [PortType.input, PortType.output]: for index, port in composite_node[port_type].items(): if port.connections: connection = port.connections[0] outside_port = connection.valid_ports[opposite_port(port_type)] internal_model, pt, pi = self._inside_ports[(port_type, index)] connections.append((outside_port, name_to_node[internal_model.caption][pt][pi])) # Update property links for (subcaption, subnode) in name_to_node.items(): if isinstance(subnode.model, BaseCompositeModel): # Get all leaf PropertyModel instances paths = subnode.model.get_leaf_paths() else: paths = [[subnode.model]] # In case selected node is composite, replace all leaf node links for path in paths: str_path = [model.caption for model in path] # Captions might have changed if subnode captions were equal to other captions # in the scene and the composite node which is being replaced contains still the # old ones old_str_path = [subcaption] + str_path[1:] old_model = composite_node.model.get_model_from_path(old_str_path) self.property_links_model.replace_item(subnode, path[-1], old_model) subnode.graphics_object.setSelected(True) scene.remove_node(composite_node) # Create outside connections only after the composite node has been deleted to prevent # creating multiple connections per input port in the outside nodes for outside, inside in connections: scene.create_connection(outside, inside, check_cycles=False) return name_to_node, connections def get_composite_model_class(composite_name, models, connections, links=None): if not composite_name: raise UfoModelError('composite name must be specified') class CompositeModel(BaseCompositeModel): name = composite_name data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, registry=None): super().__init__(models, connections, links=links, registry=registry, style=style, parent=parent) model = CompositeModel model.caption_visible = True model.caption = composite_name return model class UfoGeneralBackprojectModel(UfoTaskModel): name = 'general_backproject' num_ports = {PortType.input: 1, PortType.output: 1} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('general-backproject', style=style, parent=parent, scrollable=scrollable) self.needs_fixed_scheduler = True self.can_split_gpu_work = True def make_properties(self): properties = super().make_properties() slice_memory_coeff = NumberQLineEditViewItem(0.01, 1., default_value=0.8, tooltip='Portion of used GPU memory') properties['slice-memory-coeff'] = [slice_memory_coeff, False] return properties def split_gpu_work(self, gpus): from tofu.genreco import make_runs, DTYPE_CL_SIZE def check_region(region): if not len(np.arange(*self[region])): raise UfoModelError(f'Invalid {region} {self[region]}') # Check if ranges are OK check_region('region') check_region('x-region') check_region('y-region') gpu_indices = range(len(gpus)) bpp = DTYPE_CL_SIZE[self['store-type']] runs = make_runs(gpus, gpu_indices, self['x-region'], self['y-region'], self['region'], bpp, slice_memory_coeff=self['slice-memory-coeff']) return runs def _setup_ufo_task(self, ufo_task, region=None): separate = ['region', 'slice-memory-coeff'] task_props = [prop for prop in self if prop not in separate] for prop in task_props: setattr(ufo_task.props, prop, self[prop]) # Set region separately in case there are multiple inputs current_region = self['region'] if region is None else region setattr(ufo_task.props, 'region', current_region) class UfoVaryingInputModel(UfoTaskModel): """Base class for models which can have varying number if inputs.""" def __init__(self, task_name, style=None, parent=None, scrollable=True, num_inputs=None, dialog_title='Number of inputs', dialog_label='Number of inputs:'): if not num_inputs: num_inputs, ok = QInputDialog.getInt(parent, dialog_title, dialog_label, value=1, minValue=1, maxValue=10, step=1) if not ok: raise UfoModelError('Number of inputs must be specified') self.num_ports = {PortType.input: num_inputs, PortType.output: 1} self.data_type = {PortType.output: {0: UFO_DATA_TYPE}} self.port_caption = {PortType.output: {0: ''}} self.port_caption_visible = {PortType.output: {0: False}} self.data_type[PortType.input] = {} self.port_caption[PortType.input] = {} self.port_caption_visible[PortType.input] = {} for i in range(num_inputs): self.data_type[PortType.input][i] = UFO_DATA_TYPE self.port_caption[PortType.input][i] = '' self.port_caption_visible[PortType.input][i] = False super().__init__(task_name, style=style, parent=parent, scrollable=scrollable) def save(self): state = super().save() state['num-inputs'] = self.num_ports['input'] return state class UfoOpenCLModel(UfoVaryingInputModel): name = 'opencl' def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None): super().__init__('opencl', style=style, parent=parent, scrollable=scrollable, num_inputs=num_inputs) def _setup_ufo_task(self, ufo_task, region=None): for prop in self: if prop in ['filename', 'source']: # opencl task really needs NULL value = self[prop] if self[prop] else None else: value = self[prop] setattr(ufo_task.props, prop, value) class UfoReadModel(UfoTaskModel): name = 'read' num_ports = {PortType.input: 0, PortType.output: 1} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('read', style=style, parent=parent, scrollable=scrollable) def auto_fill(self): import glob import imageio if os.path.isdir(self['path']): paths = sorted(glob.glob(os.path.join(self['path'], '*'))) else: paths = [self['path']] num_images = 0 for path in paths: try: num_images += len(imageio.get_reader(path)) except: LOG.error(f"Error reading '{path}'") if not num_images: raise UfoModelError(f"No images found in {self['path']}") self['number'] = num_images def double_clicked(self, parent): current_path = self['path'] if not os.path.isdir(current_path): current_path = os.path.dirname(current_path) if not current_path: current_path = QtCore.QDir.homePath() dialog = FileDirDialog() if dialog.exec_(): self['path'] = dialog.selectedFiles()[0] def _setup_ufo_task(self, ufo_task, region=None): for prop in self: if prop != 'raw-bitdepth' or self['raw-bitdepth']: setattr(ufo_task.props, prop, self[prop]) class UfoRetrievePhaseModel(UfoVaryingInputModel): name = 'retrieve_phase' def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None): super().__init__('retrieve-phase', style=style, parent=parent, scrollable=scrollable, dialog_title='Multi-distance Setup', dialog_label='Number of distances:', num_inputs=num_inputs) def make_properties(self): properties = super().make_properties() # Override distance property based on how many inputs we expect tooltip = properties['distance'][0].widget.toolTip() item = RangeQLineEditViewItem(tooltip=tooltip, default_value=[], num_items=self.num_ports['input'], is_float=True) properties['distance'] = [item, True] if self.num_ports['input'] > 1: properties['method'][0].set('ctf_multidistance') properties['method'][0].widget.setEnabled(False) properties['distance-x'][0].widget.setEnabled(False) properties['distance-y'][0].widget.setEnabled(False) return properties class UfoWriteModel(UfoTaskModel): name = 'write' num_ports = {PortType.input: 1, PortType.output: 0} data_type = UFO_DATA_TYPE def __init__(self, style=None, parent=None, scrollable=True): super().__init__('write', style=style, parent=parent, scrollable=scrollable) def double_clicked(self, parent): current_path = os.path.dirname(self['filename']) if not current_path: current_path = QtCore.QDir.homePath() file_name, _ = QFileDialog.getSaveFileName(None, "Select File Name", current_path) if file_name: self['filename'] = file_name @property def expects_multiple_inputs(self): return '{region}' in self['filename'] def _setup_ufo_task(self, ufo_task, region=None): if region is not None and not self.expects_multiple_inputs: raise UfoModelError('Write got region without enabling multiple inputs. ' 'Add {region} somewhere in the "filename" field to enable it.') super()._setup_ufo_task(ufo_task, region=region) filename = self['filename'] if region is not None and self.expects_multiple_inputs: filename = filename.format(region=region[0]) setattr(ufo_task.props, 'filename', filename) class _Batch(QObject): finished = pyqtSignal(int) def __init__(self, ufo_task, shape, batch_id): super().__init__(parent=None) self.batch_id = batch_id self.data = np.empty(shape, dtype=np.float32) ptr = self.data.__array_interface__['data'][0] ufo_task.props.pointer = ptr ufo_task.props.max_size = self.data.nbytes ufo_task.connect('processed', self._on_processed) self.num_processed = 0 def _on_processed(self, ufo_task): self.num_processed += 1 if self.num_processed == self.data.shape[0]: self.finished.emit(self.batch_id) class UfoMemoryOutModel(UfoTaskModel): name = 'memory_out' num_ports = {PortType.input: 1, PortType.output: 1} data_type = {PortType.input: {0: UFO_DATA_TYPE}, PortType.output: {0: ARRAY_DATA_TYPE}} port_caption = {PortType.input: {0: ''}, PortType.output: {0: ''}} port_caption_visible = {PortType.input: {0: False}, PortType.output: {0: False}} def __init__(self, style=None, parent=None, scrollable=True): self._lock = Lock() self.reset_batches() super().__init__('memory-out', style=style, parent=parent, scrollable=scrollable) @property def expects_multiple_inputs(self): return self['number'] == '{region}' def make_properties(self): width_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input width') height_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input height') depth_item = IntQLineEditViewItem(0, 1000000, default_value=1, tooltip='Input depth (for 2D images should be 1)') number_item = QLineEditViewItem(default_value=1, tooltip='Number of inputs') properties = {'width': [width_item, True], 'height': [height_item, True], 'depth': [depth_item, True], 'number': [number_item, True]} return properties def consume_batch(self, batch_id): def consume(current_batch): LOG.debug(f'{self.caption}: consuming {current_batch.batch_id} (caller {batch_id})') self._current_data = current_batch.data self.data_updated.emit(0) # Free memory up self._batches[self._expecting_id] = None with self._lock: if self._expecting_id == batch_id: consume(self._batches[self._expecting_id]) self._expecting_id += 1 while self._expecting_id in self._waiting_list: consume(self._batches[self._expecting_id]) del self._waiting_list[self._waiting_list.index(self._expecting_id)] self._expecting_id += 1 else: LOG.debug(f'{self.caption}: putting {batch_id} on waiting list') self._waiting_list.append(batch_id) def out_data(self, port: int) -> NodeData: LOG.debug(f'{self.caption}: out_data shape:' f'{None if self._current_data is None else self._current_data.shape}') return self._current_data def reset_batches(self): self._batches = [] self._waiting_list = [] self._expecting_id = 0 self._current_data = None def _setup_ufo_task(self, ufo_task, region=None): if region is not None and not self.expects_multiple_inputs: raise UfoModelError('Memory Out got region without enabling multiple inputs. ' 'Type {region} in the "number" field to enable it.') number = int(self['number']) if region is None else len(np.arange(*region)) shape = (number, self['height'], self['width']) with self._lock: batch = _Batch(ufo_task, shape, len(self._batches)) self._batches.append(batch) batch.finished.connect(self.consume_batch) class ImageViewerModel(UfoModel): name = 'image_viewer' caption = 'Image Viewer' num_ports = {PortType.input: 1, PortType.output: 0, } data_type = ARRAY_DATA_TYPE def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._node_data = None from tofu.flow.viewer import ImageViewer self._widget = ImageViewer() self._reset = True def embedded_widget(self): return self._widget def resizable(self): return True def double_clicked(self, parent): try: if self._widget.images is not None and not self._widget.popup_visible: import pyqtgraph self._widget.popup() except ImportError: LOG.debug('pyqtgraph not installed, not popping up') def set_in_data(self, data: NodeData, port: Port): if data is not None: if self._reset: self._widget.images = data self._reset = False else: self._widget.append(data) def reset_batches(self): self._reset = True def cleanup(self): self._widget.cleanup() def get_ufo_model_classes(names=None): all_names = set(UFO_PLUGIN_MANAGER.get_all_task_names()) # stamp causes a gobject unref warning blacklist = set(['general-backproject', 'memory-in', 'memory-out', 'opencl', 'read', 'retrieve-phase', 'stamp', 'write']) all_names = list(all_names - blacklist) return (get_ufo_model_class(name) for name in names or all_names) def get_composite_model_classes_from_json(state): """ Get composite model classes from their json representation. This is recursive in case a user creates a composite inside the scene, then adds nodes and creates another composite with the first one inside and doesn't export explicitly the first one. The order of returned classes is bottom -> up, i.e. first the classes which have striclty non-composite submodels are returned and the top level class is last. """ classes = [] def go_down(current): connections = [CompositeConnection(*args) for args in current['connections']] submodels = [] for (key, model) in current['models'].items(): if 'models' in model['model'] and 'connections' in model['model']: go_down(current['models'][key]['model']) # models are tuples (name, state, visible, position) submodels.append((model['name'], model['model'], model['visible'], model['position'])) classes.append(get_composite_model_class(current['name'], submodels, connections, links=current.get('links', None))) go_down(state) return classes def get_composite_model_classes(): from xdg import xdg_data_home composite_lists = [] paths = [pkg_resources.resource_filename(__name__, 'composites'), os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites')] for path in paths: file_names = sorted(glob.glob(os.path.join(path, '*.cm'))) for file_name in file_names: LOG.debug(f'Loading composite from {file_name}') try: with open(file_name, 'r') as f: state = json.load(f) composite_lists.append(get_composite_model_classes_from_json(state)) except Exception as e: LOG.error(e, exc_info=True) return composite_lists class UfoModelError(FlowError): pass tofu-0.12.0/tofu/flow/propertylinksmodels.py000066400000000000000000000371131423713721100212040ustar00rootroot00000000000000import logging from PyQt5.QtCore import QDataStream, pyqtSignal from PyQt5.QtGui import QStandardItemModel, QStandardItem from tofu.flow.models import PropertyModel, BaseCompositeModel from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE LOG = logging.getLogger(__name__) def _decode_mime_data(data): byte_array = data.data('application/x-sourcetreemodelindex') ds = QDataStream(byte_array) row = ds.readInt32() column = ds.readInt32() internal_id = ds.readUInt64() return (row, column, internal_id) def _data_from_tree_index(index): """ Traverse parents up to the root and get the root node, model and it's property from *index*, which must be a property record (leaf in the tree). """ prop_name = index.data() index = index.parent() model = index.data(role=MODEL_ROLE) while index.data(role=NODE_ROLE) is None and index.isValid(): index = index.parent() node = index.data(role=NODE_ROLE) return (node, model, prop_name) def _get_string_path(node, model, prop_name): if isinstance(node.model, BaseCompositeModel): path = node.model.get_path_from_model(model) else: path = [model] str_path = [model.caption for model in path] str_path.append(prop_name) return str_path class NodeTreeModel(QStandardItemModel): """Tree model representing nodes in the scene.""" def add_node(self, node): item = self._add_model(node.model) if item: item.setData(node, role=NODE_ROLE) def remove_node(self, node): for j in range(self.rowCount()): item = self.item(j, 0) if item and item.data(role=NODE_ROLE) == node: self.removeRow(j) break def clear(self): """In PyQt5, clear doesn't emit the rowsAboutToBeRemoved signal and this does effectively the same. """ self.removeRows(0, self.rowCount()) self.removeColumns(0, self.columnCount()) self.rowCount(), self.columnCount() def set_nodes(self, nodes): self.clear() for node in nodes: self.add_node(node) def _add_model(self, flow_model, parent=None): if not parent: parent = self.invisibleRootItem() item = None if (isinstance(flow_model, PropertyModel) or isinstance(flow_model, BaseCompositeModel)): item = QStandardItem(flow_model.caption) item.setData(flow_model, role=MODEL_ROLE) item.setEditable(False) if isinstance(flow_model, PropertyModel): for prop in sorted(flow_model): prop_item = QStandardItem(prop) prop_item.setEditable(False) item.appendRow(prop_item) else: for submodel_name in sorted(flow_model): self._add_model(flow_model[submodel_name], parent=item) if item: parent.appendRow(item) return item class PropertyLinksModel(QStandardItemModel): """Links model representing property links between nodes in the scene.""" restored = pyqtSignal() def __init__(self, node_model): super().__init__() self._silent = {} self._slaves = {} self._node_model = node_model self._node_model.rowsAboutToBeRemoved.connect(self.on_node_rows_about_to_be_removed) def __contains__(self, key): for column in range(self.columnCount()): if self.findItems(key, column=column): return True return False def clear(self): for j in range(self.rowCount()): for i in range(self.columnCount()): self.remove_item(self.indexFromItem(self.item(j, i))) super().clear() def find_items(self, data_list, roles): result = [] for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item: success = True for (data, role) in zip(data_list, roles): if item.data(role=role) != data: success = False break if success: result.append(item) return result def get_model_links(self, models): """ Get links between *models*. Return dict {row index: [str_path, ...]}, where *str_path* is the path from the topmost model (in case of composites along the way) to the property name. """ items = {} for model in models: for item in self.find_items([model], [MODEL_ROLE]): str_path = item.text().split('->') if item.row() not in items: items[item.row()] = [str_path] else: items[item.row()].append(str_path) return items def get_root_model(self, model): root_model = None items = self.find_items([model], [MODEL_ROLE]) if items: root_model = items[0].data(role=MODEL_ROLE) else: for (silent_model, prop_name) in self._silent: if silent_model == model: root_model = self._silent[(silent_model, prop_name)][0] return root_model def get_model_properties(self, model): items = self.find_items([model], [MODEL_ROLE]) return [item.data(role=PROPERTY_ROLE) for item in items] def add_item(self, node, model, prop_name, row, column, insert=False): """ Add item where *node* is the root node (can be composite), *model* is the leaf model (there can be composites above if the leaf is nested) and *prop_name* is the property name. *row* and *column* determine the table cell to which to add the item or replace an old item with the new one. If *insert* is True, insert a new row at *row*. """ str_path = '->'.join(_get_string_path(node, model, prop_name)) if str_path in self: raise ValueError(f'{str_path} already inside') item = QStandardItem(str_path) item.setData(model, role=MODEL_ROLE) item.setData(prop_name, role=PROPERTY_ROLE) item.setData(node, role=NODE_ROLE) item.setEditable(False) if row == -1: row = self.rowCount() if column == -1: # +1 to find an empty cell even if the row is full for i in range(self.columnCount() + 1): if self.item(row, i) is None: column = i break LOG.debug(f'Add item {node.model.caption}({item.data(role=MODEL_ROLE)}):' f'{item.data(role=PROPERTY_ROLE)} at ({row}, {column})') if insert: self.insertRow(row, item) else: self.setItem(row, column, item) # In case the composite is being edit in a subwindow, connect the slave nodes from the # subsecene if isinstance(node.model, BaseCompositeModel): node.model.add_slave_links() model.property_changed.connect(self.on_property_changed) def remove_item(self, index): flow_model = index.data(role=MODEL_ROLE) if not flow_model: # Empty cell return property_name = index.data(role=PROPERTY_ROLE) flow_model.property_changed.disconnect(self.on_property_changed) self.setItem(index.row(), index.column(), None) # Remove all associated slaves root_key = (flow_model, property_name) if root_key in self._slaves: for slave_key in tuple(self._slaves[root_key]): self.remove_silent(*slave_key) def add_silent(self, model, property_name, root, root_property_name): key = (model, property_name) if key in self._silent: return model.property_changed.connect(self.on_property_changed) root_key = (root, root_property_name) if not self.find_items(root_key, (MODEL_ROLE, PROPERTY_ROLE)): raise ValueError(f'{model} not in property links') self._silent[key] = root_key if root_key not in self._slaves: self._slaves[root_key] = [key] else: self._slaves[root_key].append(key) LOG.debug(f'Slave {root}->{root_property_name} -> {model}->{property_name} added') def remove_silent(self, model, property_name): key = (model, property_name) if key not in self._silent: # Already removed, e.g. by deleting an item by del key while some composite windows were # still opened return model.property_changed.disconnect(self.on_property_changed) root_key = self._silent[key] index = self._slaves[root_key].index(key) del self._slaves[root_key][index] if not self._slaves[root_key]: del self._slaves[root_key] del self._silent[key] LOG.debug(f'Slave {model}->{property_name} removed') def replace_item(self, node, new_model, old_model): for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item and item.data(role=MODEL_ROLE) == old_model: # Don't break, replace all properties of *old_model* prop_name = item.data(role=PROPERTY_ROLE) slaves = tuple(self._slaves.get((old_model, prop_name), [])) self.remove_item(self.indexFromItem(item)) self.add_item(node, new_model, prop_name, j, i) for (slave_model, slave_property_name) in slaves: self.add_silent(slave_model, slave_property_name, new_model, prop_name) def on_node_rows_about_to_be_removed(self, parent, first, last): for k in range(first, last + 1): node = self._node_model.item(k, 0).data(role=NODE_ROLE) for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if item and item.data(role=NODE_ROLE) == node: self.remove_item(self.indexFromItem(item)) self.compact() def canDropMimeData(self, data, action, row, column, parent): can_drop = False if data.hasFormat('application/x-sourcetreemodelindex'): src_row, src_column, src_internal_id = _decode_mime_data(data) src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id) # src_model_index is the property, it's parent is the model node, flow_model, property_name = _data_from_tree_index(src_model_index) str_path = '->'.join(_get_string_path(node, flow_model, property_name)) can_drop = str_path not in self if parent.isValid(): # Parent itself can be an empty cell, so use the first column which is for sure # occupied since the parent is valid (row exists and we are not between rows) first_item = self.item(parent.row(), 0) parent_model = first_item.data(role=MODEL_ROLE) parent_property_name = first_item.data(role=PROPERTY_ROLE) if not type(flow_model[property_name]) is type(parent_model[parent_property_name]): # Data can be dropped only if the types of properties match can_drop = False return can_drop def dropMimeData(self, data, action, row, column, parent): src_row, src_column, src_internal_id = _decode_mime_data(data) src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id) node, flow_model, property_name = _data_from_tree_index(src_model_index) if parent.isValid(): row = parent.row() insert = False else: insert = True # drops never replace items and column=-1 means "find an empty cell" self.add_item(node, flow_model, property_name, row, -1, insert=insert) return True def save(self): state = [] for j in range(self.rowCount()): row_state = [] for i in range(self.columnCount()): item = self.item(j, i) if not item: continue node = item.data(role=NODE_ROLE) model = item.data(role=MODEL_ROLE) prop_name = item.data(role=PROPERTY_ROLE) str_path = _get_string_path(node, model, prop_name) row_state.append([node.id, str_path]) state.append(row_state) return state def restore(self, state, nodes): self.clear() for (j, row) in enumerate(state): for (i, (node_id, path)) in enumerate(row): node = nodes[node_id] # Last path entry is the property name if isinstance(node.model, BaseCompositeModel): flow_model = node.model.get_model_from_path(path[1:-1]) else: flow_model = node.model self.add_item(node, flow_model, path[-1], j, i) self.restored.emit() def compact(self): # Shift rows to the left for j in range(self.rowCount()): filled = [] for i in range(self.columnCount()): if self.item(j, i): filled.append(self.takeItem(j, i)) for (i, item) in enumerate(filled): self.setItem(j, i, item) # Check empty rows for j in range(self.rowCount())[::-1]: is_empty = True for i in range(self.columnCount()): if self.item(j, i): is_empty = False if is_empty: self.removeRow(j) # Check empty columns for i in range(self.columnCount())[::-1]: is_empty = True for j in range(self.rowCount()): if self.item(j, i): is_empty = False if is_empty: self.removeColumn(i) def on_property_changed(self, sig_model, sig_property_name, value): LOG.debug(f'on_property_changed: {sig_model}, {sig_model.caption}, ' f'{sig_property_name}, {value}') sig_key = (sig_model, sig_property_name) if sig_key in self._silent: # pyqtSignal came from a composite subwindow, get root model from the silent slave root_key = self._silent[sig_key] root_key[0][root_key[1]] = value else: root_key = (sig_model, sig_property_name) row = -1 for j in range(self.rowCount()): for i in range(self.columnCount()): item = self.item(j, i) if (item and item.data(role=MODEL_ROLE) == root_key[0] and item.data(role=PROPERTY_ROLE) == root_key[1]): row = j break if row != -1: break for i in range(self.columnCount()): item = self.item(row, i) if item: model = item.data(role=MODEL_ROLE) property_name = item.data(role=PROPERTY_ROLE) if root_key != (model, property_name): model[property_name] = value # Notify all slaves key = (model, property_name) if key in self._slaves: for (slave_model, slave_property_name) in self._slaves[key]: if (slave_model, slave_property_name) != (sig_model, sig_property_name): slave_model[slave_property_name] = value tofu-0.12.0/tofu/flow/propertylinkswidget.py000066400000000000000000000070231423713721100212010ustar00rootroot00000000000000from PyQt5.QtCore import QMimeData, Qt, QDataStream, QByteArray, QIODevice, QModelIndex from PyQt5.QtGui import QDrag from PyQt5.QtWidgets import QAbstractItemView, QLabel, QTableView, QTreeView, QVBoxLayout, QWidget def _encode_mime_data(index: QModelIndex): """Encode item in *index* into :class:`QMimeData`.""" mime_data = QMimeData() data = QByteArray() stream = QDataStream(data, QIODevice.WriteOnly) try: stream.writeInt32(index.row()) stream.writeInt32(index.column()) stream.writeUInt64(index.internalId()) finally: stream.device().close() mime_data.setData("application/x-sourcetreemodelindex", data) return mime_data class PropertyLinksView(QTableView): """Table view for displaying node property links.""" def keyPressEvent(self, event): if event.key() == Qt.Key_Delete: model = self.model() for index in self.selectedIndexes(): model.remove_item(index) model.compact() class NodesView(QTreeView): """Tree view displaying nodes in the scene.""" def get_drag_index(self): selected = self.selectedIndexes() if not selected: return index = selected[0] if index.child(0, 0).row() != -1: return return index def mouseMoveEvent(self, event): """All that a mouse *event* can do is start a drag and drop operation.""" index = self.get_drag_index() if not index: return drag = QDrag(self) mime_data = _encode_mime_data(index) drag.setMimeData(mime_data) drag.exec_(Qt.CopyAction) return True class PropertyLinks(QWidget): """Widget displaying nodes in the scene and their property links in one window.""" def __init__(self, node_model, table_model, parent=None): super().__init__(parent=parent, flags=Qt.Window) self.setWindowTitle('Property Links') self.resize(600, 800) self._treeview = NodesView() self._treeview.setHeaderHidden(True) self._treeview.setAlternatingRowColors(True) self._treeview.setDragEnabled(True) self._treeview.setAcceptDrops(False) self._treeview.setModel(node_model) node_model.itemChanged.connect(self.on_node_model_changed) self._table_view = PropertyLinksView() self._table_view.setDragDropOverwriteMode(False) self._table_view.setDragDropMode(QAbstractItemView.DropOnly) table_model.itemChanged.connect(self.on_table_model_changed) table_model.rowsInserted.connect(self.on_table_model_rows_inserted) table_model.restored.connect(self.on_table_model_restored) self._table_view.setModel(table_model) main_layout = QVBoxLayout() main_layout.addWidget(self._treeview) main_layout.addWidget(QLabel('Drag properties from above to the area below')) main_layout.addWidget(self._table_view) self.setLayout(main_layout) def show(self): self._table_view.resizeColumnsToContents() self._treeview.sortByColumn(0, Qt.AscendingOrder) super().show() def on_table_model_changed(self, item): self._table_view.resizeColumnToContents(item.column()) def on_table_model_rows_inserted(self, index, start, stop): self._table_view.resizeColumnToContents(0) def on_table_model_restored(self): self._table_view.resizeColumnsToContents() def on_node_model_changed(self, item): self._treeview.sortByColumn(0, Qt.AscendingOrder) tofu-0.12.0/tofu/flow/runslider.py000066400000000000000000000160531423713721100170620ustar00rootroot00000000000000from functools import partial from PyQt5.QtCore import Qt, pyqtSignal, QTimer from PyQt5 import QtGui from PyQt5.QtWidgets import QGridLayout, QLineEdit, QWidget, QSlider from tofu.flow.models import IntQLineEditViewItem, RangeQLineEditViewItem, UfoIntValidator from tofu.flow.util import FlowError class RunSlider(QWidget): value_changed = pyqtSignal(float) def __init__(self, parent=None): super().__init__(parent=parent, flags=Qt.Window) self.setWindowFlag(Qt.WindowStaysOnTopHint) self.setMaximumHeight(20) self.setMinimumWidth(600) self.min_edit = QLineEdit() self.min_edit.setToolTip('Minimum') self.min_edit.setMaximumWidth(80) self.min_edit.editingFinished.connect(self.on_min_edit_editing_finished) self.current_edit = QLineEdit() self.current_edit.setToolTip('Current value') self.current_edit.editingFinished.connect(self.on_current_edit_editing_finished) self.max_edit = QLineEdit() self.max_edit.setToolTip('Maximum') self.max_edit.setMaximumWidth(80) self.max_edit.editingFinished.connect(self.on_max_edit_editing_finished) self.slider = QSlider(orientation=Qt.Horizontal) self.slider.setMinimum(0) self.slider.setMaximum(100) self.slider.valueChanged.connect(self.on_slider_value_changed) main_layout = QGridLayout() main_layout.addWidget(self.current_edit, 0, 0, 1, 3, Qt.AlignHCenter) main_layout.addWidget(self.min_edit, 1, 0) main_layout.addWidget(self.slider, 1, 1) main_layout.addWidget(self.max_edit, 1, 2) self.setLayout(main_layout) self.view_item = None self.real_minimum = 0 self.real_maximum = 100 self.real_span = 100 self.type = None self._last_value = None self.setEnabled(False) def _update_range(self, current=None): self.real_span = self.real_maximum - self.real_minimum if current is not None: self.slider.blockSignals(True) self.slider.setValue(int(round((current - self.real_minimum) / self.real_span * 100))) self.slider.blockSignals(False) def get_real_value(self): # First convert possible exponents to float (in case UFO has huge defaults set) return self.type(float(self.current_edit.text())) def set_widget_value(self): value = self.get_real_value() self._last_value = value if isinstance(self.view_item, RangeQLineEditViewItem): value = [value] self.view_item.set(value) # Notify linked widgets self.view_item.property_changed.emit(self.view_item) def set_current_validator(self): if self.type == int: validator = UfoIntValidator(self.real_minimum, self.real_maximum) else: validator = QtGui.QDoubleValidator(self.real_minimum, self.real_maximum, 1000) self.current_edit.setValidator(validator) def setup(self, view_item): if self.view_item == view_item: return False current = view_item.get() if isinstance(view_item, RangeQLineEditViewItem): if len(current) > 1: return False self.type = float current = current[0] d_current = 0.1 * abs(current) if current else 100 self.real_minimum = current - d_current self.real_maximum = current + d_current else: self.type = int if isinstance(view_item, IntQLineEditViewItem) else float self.real_minimum = view_item.widget.validator().bottom() self.real_maximum = view_item.widget.validator().top() self.view_item = view_item self._update_range(current=current) _set_number(self.min_edit, self.real_minimum) _set_number(self.max_edit, self.real_maximum) _set_number(self.current_edit, current) self._last_value = current self.setEnabled(True) self.set_current_validator() return True def reset(self): self.real_minimum = 0 self.real_maximum = 100 self.real_span = 100 self._last_value = None self.type = None self.min_edit.setText('') self.max_edit.setText('') self.current_edit.setText('') self.setWindowTitle('') self.view_item = None self.setEnabled(False) def on_slider_value_changed(self, value): def delayed_update(init_value): current_value = self.slider.value() if init_value == current_value: self.set_widget_value() self.value_changed.emit(real_value) if self.view_item: real_value = self.slider.value() / 100 * self.real_span + self.real_minimum self.current_edit.setText('{:g}'.format(self.type(real_value))) func = partial(delayed_update, value) QTimer.singleShot(100, func) def on_current_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.current_edit.text()) except ValueError: raise RunSliderError('Not a number') if value == self._last_value: # Nothing new, do not emit value_changed signal in case the app is closing return self.slider.blockSignals(True) self.slider.setValue(int(round((value - self.real_minimum) / self.real_span * 100))) self.slider.blockSignals(False) self.set_widget_value() self.value_changed.emit(value) def on_min_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.min_edit.text()) except ValueError: raise RunSliderError('Not a number') if value >= self.real_maximum: raise RunSliderError('Minimum must be smaller than maximum') current = self.get_real_value() self.real_minimum = value if current < self.real_minimum: current = self.real_minimum self.current_edit.setText('{:g}'.format(current)) self.set_widget_value() self.value_changed.emit(current) self._update_range(current=current) self.set_current_validator() def on_max_edit_editing_finished(self): if not self.view_item: return try: value = self.type(self.max_edit.text()) except ValueError: raise RunSliderError('Not a number') if value <= self.real_minimum: raise RunSliderError('Maximum must be greater than minimum') current = self.get_real_value() self.real_maximum = value if current > self.real_maximum: current = self.real_maximum self.current_edit.setText('{:g}'.format(current)) self.set_widget_value() self.value_changed.emit(current) self._update_range(current=current) self.set_current_validator() def _set_number(edit, number): edit.setText('{:g}'.format(number)) class RunSliderError(FlowError): pass tofu-0.12.0/tofu/flow/scene.py000066400000000000000000000415761423713721100161600ustar00rootroot00000000000000import logging import numpy as np import networkx as nx from PyQt5.QtCore import pyqtSignal, QObject from PyQt5.QtWidgets import QInputDialog from qtpynodeeditor import FlowScene, NodeDataModel, PortType, opposite_port from tofu.flow.models import (BaseCompositeModel, ImageViewerModel, PropertyModel, UFO_DATA_TYPE, get_composite_model_class, get_composite_model_classes_from_json) from tofu.flow.util import CompositeConnection, FlowError, saved_kwargs from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel LOG = logging.getLogger(__name__) class UfoScene(FlowScene): nodes_duplicated = pyqtSignal(list, dict) # view item, its name and model name item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel) def __init__(self, registry=None, style=None, parent=None, allow_node_creation=True, allow_node_deletion=True): super().__init__(registry=registry, style=style, parent=parent, allow_node_creation=allow_node_creation, allow_node_deletion=allow_node_deletion) self._composite_nodes = {} self._selected_nodes_on_disabled = [] self.node_model = NodeTreeModel() self.node_model.setColumnCount(1) self.property_links_model = PropertyLinksModel(self.node_model) self.style_collection.node.opacity = 1 self.style_collection.connection.use_data_defined_colors = True self.node_double_clicked.connect(self.on_node_double_clicked) def __getstate__(self): state = super().__getstate__() state['property-links'] = self.property_links_model.save() return state def __setstate__(self, doc): for node in doc['nodes']: model = node['model'] if 'models' in model and 'connections' in model: # First register the composite model models = get_composite_model_classes_from_json(model) for model in models: self.registry.register_model(model, category='Composite', registry=self.registry) # Restore the scene super().__setstate__(doc) # and the property link models and widgets if 'property-links' in doc: self.node_model.set_nodes(self.nodes.values()) self.property_links_model.restore(doc['property-links'], self.nodes) def create_node(self, data_model, restore_links=True): """Overrides :class:`FlowScene` in order to create a node with *data_model* with a unique caption. """ LOG.debug(f'Create node with model {data_model}') node = super().create_node(data_model) self._setup_new_node(node) if restore_links and isinstance(node.model, BaseCompositeModel): node.model.restore_links(node) return node def restore_node(self, node_json): LOG.debug(f"Restore node with model {node_json['model']['name']}") with saved_kwargs(self.registry, node_json['model']): node = super().restore_node(node_json) self._setup_new_node(node) return node def on_item_focus_in(self, view_item, prop_name, caption, model): self.item_focus_in.emit(view_item, prop_name, caption, model) def _setup_new_node(self, node): self._set_unique_caption(node) self.node_model.add_node(node) if isinstance(node.model, BaseCompositeModel): node.model.property_links_model = self.property_links_model node.model.item_focus_in.connect(self.on_item_focus_in) def _set_unique_caption(self, new_node): caption = new_node.model.caption captions = [node.model.caption for node in self.nodes.values() if node != new_node] if caption in captions: fmt = new_node.model.base_caption + ' {}' i = 2 while fmt.format(i) in captions: i += 1 caption = fmt.format(i) new_node.model.caption = caption def remove_node(self, node): if hasattr(node.model, 'cleanup'): node.model.cleanup() if (isinstance(node.model, BaseCompositeModel) and node.model.name in self._composite_nodes): del self._composite_nodes[node.model.name] self.node_model.remove_node(node) super().remove_node(node) def is_selected_one_composite(self): result = False nodes = self.selected_nodes() if len(nodes) == 1: result = isinstance(nodes[0].model, BaseCompositeModel) return result def skip_nodes(self): selected_nodes = self.selected_nodes() # First check if the selected nodes may be skipped for node in selected_nodes: if (node.model.num_ports[PortType.input] != 1 or node.model.num_ports[PortType.output] != 1): raise FlowError('Only nodes with one input and one output can be skipped') ports = list(node.state.ports) if ports[0].data_type != UFO_DATA_TYPE or ports[1].data_type != UFO_DATA_TYPE: raise FlowError('Only tasks with UFO input and output can be skipped') # And only if all is fine, then skip them for node in selected_nodes: node.model.skip = not node.model.skip opacity = 0.5 if node.model.skip else 1 node.state.input_connections[0].graphics_object.setOpacity(opacity) node.state.output_connections[0].graphics_object.setOpacity(opacity) node.graphics_object.setOpacity(opacity) def auto_fill(self): for node in self.nodes.values(): if isinstance(node.model, BaseCompositeModel): paths = node.model.get_leaf_paths() else: paths = [[node.model]] for path in paths: model = path[-1] if isinstance(model, PropertyModel): model.auto_fill() def copy_nodes(self): new_nodes = {} selected_nodes = self.selected_nodes() # Create nodes for node in selected_nodes: new_node = self.create_node(node.model) new_nodes[node] = new_node values = node.model.save() new_node.model.restore(values, restore_caption=False) # Create connections for node, new_node in new_nodes.items(): for connection in self.connections: port = connection.ports[0] in_index = port.index out_index = connection.ports[1].index if port.node == node: other_node = connection.ports[1].node if other_node in new_nodes: # Other node has been also selected self.create_connection_by_index(new_node, in_index, new_nodes[other_node], out_index, None) self.nodes_duplicated.emit(selected_nodes, new_nodes) def create_composite(self): composite_name, ok = QInputDialog.getText(None, 'Create Composite Node', 'Name:') if not ok: return if composite_name in self.registry.registered_model_creators(): raise FlowError(f'Composite node with name "{composite_name}" has already ' 'been registered') self._composite_nodes[composite_name] = {} connection_replacements = [] models = [] connections = [] selected_nodes = self.selected_nodes() for node in selected_nodes: unique_name = node.model.caption models.append((node.model.name, node.model.save(), True, node.__getstate__()['position'])) self._composite_nodes[composite_name][unique_name] = node.__getstate__() # Connections assigned_ports = [] x = [] y = [] for node in selected_nodes: x.append(node.position.x()) y.append(node.position.y()) for port_type in ['input', 'output']: for index, port in node[port_type].items(): if port.connections: # We allow only one connection conn = port.connections[0] other_port = conn.ports[0] if conn.ports[1] == port else conn.ports[1] other = conn.get_node(opposite_port(port_type)) if (other in selected_nodes and port not in assigned_ports and other_port not in assigned_ports): # Connection reaches to a node outside selection if port_type == PortType.input: to_node_name = node.model.caption to_node_index = index from_node_name = other.model.caption from_node_index = other_port.index else: to_node_name = other.model.caption to_node_index = other_port.index from_node_name = node.model.caption from_node_index = index conn = CompositeConnection(from_node_name, from_node_index, to_node_name, to_node_index) connections.append(conn) assigned_ports.append(port) if other not in selected_nodes: inside = (node.model.caption, port_type, index) connection_replacements.append((other_port, inside)) # Get links which will be internal to the newly created model node_models = [] for selected_node in self.selected_nodes(): if isinstance(selected_node.model, BaseCompositeModel): paths = selected_node.model.get_leaf_paths() else: paths = [[selected_node.model]] node_models += [path[-1] for path in paths] internal_links = list(self.property_links_model.get_model_links(node_models).values()) composite = get_composite_model_class(composite_name, models, connections, links=internal_links) self.registry.register_model(composite, category='Composite', registry=self.registry) node = self.create_node(composite, restore_links=False) for selected_node in selected_nodes: if isinstance(selected_node.model, BaseCompositeModel): # Get all leaf PropertyModel instances paths = selected_node.model.get_leaf_paths() else: paths = [[selected_node.model]] # In case selected node is composite, replace all leaf node links for path in paths: new_model = node.model.get_model_from_path([model.caption for model in path]) self.property_links_model.replace_item(node, new_model, path[-1]) self.remove_node(selected_node) for outside_port, inside in connection_replacements: port_type, index = node.model.get_outside_port(*inside) self.create_connection(outside_port, node[port_type][index], check_cycles=False) # Put the new composite node to the average of x and y position of the selected nodes node.position = (np.mean(x), np.mean(y)) node.graphics_object.setSelected(True) return node def on_node_double_clicked(self, node): views = self.views() if views: node.model.double_clicked(views[0]) def expand_composite(self, node): name = node.model.name original_nodes = self._composite_nodes.get(name, None) return node.model.expand_into_scene(self, node, original_nodes=original_nodes) def is_fully_connected(self): """Are all the ports in all nodes connected?""" def are_ports_connected(node, port_type): for port in node[port_type].values(): if not port.connections: return False return True for node in self.nodes.values(): if not are_ports_connected(node, 'input'): return False if not are_ports_connected(node, 'output'): return False return True def get_simple_node_graphs(self): """ Get a graph from the scene without composite nodes which can be directly used byt the execution. """ def get_composite(graph): """Get first found composite model.""" for model in graph.nodes: if isinstance(model, BaseCompositeModel): return model def replace_edge(graph, composite, edges, port_type): """Replace interface edges (going in or out from the composite model).""" for edge in edges: ports = graph.edges[edge] other = edge[0] if port_type == PortType.input else edge[1] model, index = composite.get_model_and_port_index(port_type, ports[port_type]) if model not in graph: graph.add_node(model) if port_type == PortType.input: source = other dest = model input_port = index output_port = ports[PortType.output] else: source = model dest = other input_port = ports[PortType.input] output_port = index LOG.debug(f'Adding edge {source.name}@{output_port} -> {dest.name}@{input_port}') graph.add_edge(source, dest, input=input_port, output=output_port) def replace_composite(graph, composite): composite.expand_into_graph(graph) edges = graph.in_edges(composite, keys=True) replace_edge(graph, composite, edges, PortType.input) edges = graph.out_edges(composite, keys=True) replace_edge(graph, composite, edges, PortType.output) graph.remove_node(composite) # Initial graph with composite nodes. We need a multigraph because composite nodes may have # many outputs which can lead to a same destination node. graph = nx.MultiDiGraph() for node in self.nodes.values(): if not node.model.skip: graph.add_node(node.model) for conn in self.connections: p_dest, p_source = conn.ports if p_dest.node.model.skip: LOG.debug(f'Skiping connection {p_source.node.model.name} -> ' f'{p_dest.node.model.name}') continue while p_source.node.model.skip: LOG.debug(f'Skiping connection {p_source.node.model.name} -> ' f'{p_dest.node.model.name}') previous_conn = p_source.node.state.input_connections[0] previous_node = previous_conn.output_node p_source = list(previous_node.state.output_ports)[0] graph.add_edge(p_source.node.model, p_dest.node.model, input=p_dest.index, output=p_source.index) # Expand composite nodes until there are only simple ones left model = get_composite(graph) while model: LOG.debug(f'Replacing composite {model.name}') replace_composite(graph, model) model = get_composite(graph) components = nx.weakly_connected_components(graph) return [nx.subgraph(graph, component) for component in components] def set_enabled(self, enabled): selected_nodes = self.selected_nodes() self.allow_node_creation = enabled self.allow_node_deletion = enabled for node in self.nodes.values(): if not isinstance(node.model, ImageViewerModel): node.graphics_object.setEnabled(enabled) if enabled: if node in self._selected_nodes_on_disabled: node.graphics_object.setSelected(True) else: if node in selected_nodes: self._selected_nodes_on_disabled.append(node) for conn in self.connections: conn._graphics_object.setEnabled(enabled) if enabled: self._selected_nodes_on_disabled = [] tofu-0.12.0/tofu/flow/util.py000066400000000000000000000044321423713721100160260ustar00rootroot00000000000000import contextlib import json import pkg_resources from PyQt5.QtCore import Qt from qtpynodeeditor import PortType MODEL_ROLE = Qt.UserRole + 1 PROPERTY_ROLE = MODEL_ROLE + 1 NODE_ROLE = PROPERTY_ROLE + 1 with open(pkg_resources.resource_filename(__name__, 'config.json')) as f: ENTRIES = json.load(f) def get_config_key(*keys, default=None): current = ENTRIES.get(keys[0], default) if current != default and len(keys) > 1: for key in keys[1:]: current = current.get(key, default) if current == default: break return current @contextlib.contextmanager def saved_kwargs(registry, state): """ Tell the registry to use the number of saved inputs for model creation but only for one model creation, i.e. reset the context afterward. """ if 'num-inputs' in state: kwargs = registry.registered_model_creators()[state['name']][1] kwargs['num_inputs'] = state['num-inputs'] try: yield finally: if 'num-inputs' in state: del kwargs['num_inputs'] class CompositeConnection: def __init__(self, from_unique_name, from_port_index, to_unique_name, to_port_index): if from_unique_name == to_unique_name: raise ValueError('from_unique_name and to_unique_name must be different') self.from_unique_name = from_unique_name self.from_port_index = from_port_index self.to_unique_name = to_unique_name self.to_port_index = to_port_index def contains(self, unique_name, port_type, port_index): is_from = is_to = False if port_type == PortType.output: is_from = (unique_name == self.from_unique_name and port_index == self.from_port_index) else: is_to = (unique_name == self.to_unique_name and port_index == self.to_port_index) return is_from or is_to def save(self): return [self.from_unique_name, self.from_port_index, self.to_unique_name, self.to_port_index] def __str__(self): return repr(self) def __repr__(self): fmt = 'Connection({}@{} -> {}@{})' return fmt.format(self.from_unique_name, self.from_port_index, self.to_unique_name, self.to_port_index) class FlowError(Exception): pass tofu-0.12.0/tofu/flow/viewer.py000066400000000000000000000451001423713721100163470ustar00rootroot00000000000000import logging import numpy as np import os from PyQt5 import QtGui from PyQt5.QtCore import Qt from PyQt5.QtWidgets import QFileDialog, QGridLayout, QLabel, QLineEdit, QMenu, QWidget, QSlider from tofu.flow.util import FlowError LOG = logging.getLogger(__name__) class ScreenImage: """On-screen image representation.""" def __init__(self, image=None): self._black_point = None self._white_point = None self.minimum = None self.maximum = None self.image = image @property def image(self): return self._image @image.setter def image(self, image): """ Keep the minimum, maximum, black and white points as they are so that images don't flicker when going through a sequence. """ self._image = image if self._image is not None: self._image = image.astype(np.float32) if self.minimum is None: self.minimum = np.nanmin(self._image) if self.maximum is None: self.maximum = np.nanmax(self._image) if self.black_point is None: self.black_point = self.minimum if self.white_point is None: self.white_point = self.maximum @property def white_point(self): return self._white_point @white_point.setter def white_point(self, value): if self.black_point is not None and value < self.black_point: raise ImageViewingError('White point cannot be smaller than black point') self._white_point = value @property def black_point(self): return self._black_point @black_point.setter def black_point(self, value): if self.white_point is not None and value > self.white_point: raise ImageViewingError('Black point cannot be greater than white point') self._black_point = value def reset(self): """Reset black and white points.""" if self._image is not None: self.minimum = np.nanmin(self._image) self.maximum = np.nanmax(self._image) self._black_point = self.minimum self._white_point = self.maximum def auto_levels(self, percentile=0.1): """ Compute cumulative histogram normalized to [0, 100] and truncate gray values which fall below *percentile* or above 100 - *percentile*. """ hist, bins = np.histogram(self._image, bins=256) cumsum = np.cumsum(hist) / float(np.sum(hist)) * 100 valid = bins[np.where((cumsum > percentile) & (cumsum < 100 - percentile))] if len(valid): self.black_point = valid[0] self.white_point = valid[-1] else: self.black_point = self.white_point = self._image[0, 0] def set_black_point_normalized(self, value): """Set black point according to *value*, where value is from interval [0, 255].""" native = self.convert_normalized_value_to_native(value) if native > self.white_point: raise ImageViewingError('Black point cannot be greater than white point') self.black_point = native def set_white_point_normalized(self, value): """Set white point according to *value*, where value is from interval [0, 255].""" native = self.convert_normalized_value_to_native(value) if native < self.black_point: raise ImageViewingError('White point cannot be smaller than white point') self.white_point = native def convert_normalized_value_to_native(self, value): """Convert *value* from interval [0, 255] to the gray value in the image.""" if value < 0 or value > 255: raise ImageViewingError('Normalized value must be in interval [0, 255]') span = self.maximum - self.minimum return value / 255 * span + self.minimum def convert_native_value_to_normalized(self, value): """Convert gray value in the image to a normalized value in interval [0, 255].""" if value < self.minimum or value > self.maximum: raise ImageViewingError(f'Value must be in interval [{self.minimum}, {self.maximum}]') span = self.maximum - self.minimum return (value - self.minimum) / span * 255 if span > 0 else 0 def get_pixmap(self, downsampling=1): """Get :class:`QPixmap` for display.""" if self.black_point is None or self.white_point is None: raise ImageViewingError('Image has not been set') image = self.image[::downsampling, ::downsampling] - self.black_point if self.white_point - self.black_point > 0: image = np.clip(image * 255 / (self.white_point - self.black_point), 0, 255) image = image.astype(np.uint8) qim = QtGui.QImage(image, image.shape[1], image.shape[0], image[0].nbytes, QtGui.QImage.Format.Format_Grayscale8) return QtGui.QPixmap.fromImage(qim) class ImageLabel(QLabel): """QLabel holding the image data.""" def __init__(self, screen_image=None, parent=None): super().__init__(parent=parent) self.screen_image = screen_image def updateImage(self): if self.screen_image and self.screen_image.image is not None: hd = self.screen_image.image.shape[1] // self.width() vd = self.screen_image.image.shape[0] // self.height() downsampling = max(min(hd, vd), 1) pixmap = self.screen_image.get_pixmap(downsampling=downsampling) self.setPixmap(pixmap.scaled(self.width(), self.height(), Qt.KeepAspectRatio)) def resizeEvent(self, event): self.updateImage() class ImageViewer(QWidget): edit_height = 16 edit_width = 100 def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._images = None self._last_save_dir = '.' # Pyqtgraph popped up window self._pg_window = None self.screen_image = ScreenImage() self.new_image_auto_levels = True self.label = ImageLabel(self.screen_image) self.label.setAlignment(Qt.AlignVCenter | Qt.AlignCenter) self.slider_edit = QLineEdit() self.slider_edit.setFixedSize(self.edit_width, self.edit_height) self.slider_edit.returnPressed.connect(self.on_slider_edit_return_pressed) self.slider = QSlider(Qt.Horizontal) validator = QtGui.QIntValidator(0, self.slider.maximum()) self.slider_edit.setValidator(validator) self.slider.valueChanged.connect(self.on_slider_value_changed) self.min_slider = QSlider(Qt.Horizontal) self.max_slider = QSlider(Qt.Horizontal) self.min_slider_edit = QLineEdit() self.min_slider_edit.setFixedSize(self.edit_width, self.edit_height) self.max_slider_edit = QLineEdit() self.max_slider_edit.setFixedSize(self.edit_width, self.edit_height) self.min_slider.setMinimum(0) self.max_slider.setMinimum(0) self.min_slider.setMaximum(255) self.max_slider.setMaximum(255) self.max_slider.setValue(255) self.min_slider.valueChanged.connect(self.on_min_slider_value_changed) self.max_slider.valueChanged.connect(self.on_max_slider_value_changed) self.min_slider_edit.returnPressed.connect(self.on_min_slider_edit_return_pressed) self.max_slider_edit.returnPressed.connect(self.on_max_slider_edit_return_pressed) # Tooltips self.slider.setToolTip('Image index in sequence') self.slider_edit.setToolTip(self.slider.toolTip()) self.min_slider.setToolTip('Black point') self.min_slider_edit.setToolTip(self.min_slider.toolTip()) self.max_slider.setToolTip('White point') self.max_slider_edit.setToolTip(self.min_slider.toolTip()) mainLayout = QGridLayout() mainLayout.addWidget(self.label, 0, 0, 1, 2) mainLayout.addWidget(self.slider_edit, 1, 0) mainLayout.addWidget(self.slider, 1, 1) mainLayout.addWidget(self.min_slider_edit, 2, 0) mainLayout.addWidget(self.min_slider, 2, 1) mainLayout.addWidget(self.max_slider_edit, 3, 0) mainLayout.addWidget(self.max_slider, 3, 1) self.setLayout(mainLayout) def contextMenuEvent(self, event): contextMenu = QMenu(self) reset_action = contextMenu.addAction('Reset') auto_levels_action = contextMenu.addAction('Auto Levels') new_image_auto_levels = contextMenu.addAction('Auto Levels on New Image') new_image_auto_levels.setCheckable(True) new_image_auto_levels.setChecked(self.new_image_auto_levels) pop_action = None save_action = None try: import pyqtgraph if self._images is not None and not self.popup_visible: pop_action = contextMenu.addAction('Pop Up') except: LOG.debug('pyqtgraph not installed, pop up option disabled') try: import imageio if self._images is not None: save_action = contextMenu.addAction('Save') except: LOG.debug('imageio not installed, save option disabled') action = contextMenu.exec_(self.mapToGlobal(event.pos())) if not action: return if action == save_action: file_name, _ = QFileDialog.getSaveFileName(None, "Select File Name", self._last_save_dir, "Images (*.tif *.png *.jpg)") if file_name: if not os.path.splitext(file_name)[1]: file_name += '.tif' self._last_save_dir = os.path.dirname(file_name) if self._images.shape[0] == 1: imageio.imsave(file_name, self._images[0]) else: if os.path.splitext(file_name)[1] != '.tif': raise ImageViewingError('3D data can be stored only in tif format') # bigtiff size from tifffile imageio.volsave(file_name, self._images, bigtiff=self._images.nbytes > 2 ** 32 - 2 ** 25) elif action == reset_action: self.reset_clim() elif action == auto_levels_action: self.reset_clim(auto=True) elif action == new_image_auto_levels: self.new_image_auto_levels = action.isChecked() elif action == pop_action: self.popup() @property def images(self): return self._images @images.setter def images(self, images): was_none = self._images is None self._images = images if self._images is None: self.screen_image.image = None self.set_enabled_adjustments(False) return self.set_enabled_adjustments(True) if self._images.ndim == 2: self._images = self._images[np.newaxis, :, :] if self._images.shape[0] == 1: self.slider.hide() self.slider_edit.hide() else: self.slider.setMaximum(len(self._images) - 1) self.slider.show() self.slider_edit.show() self.slider_edit.setText('0') self.slider.blockSignals(True) self.slider.setValue(0) self.slider.blockSignals(False) if self._pg_window is not None: self._update_pg_window_images() self._update_pg_window_index() self.screen_image.image = self._images[0] if was_none or self.new_image_auto_levels: self.reset_clim(auto=True) else: self.label.updateImage() validator = self.min_slider_edit.validator() if validator is None: validator = QtGui.QDoubleValidator(self.screen_image.minimum, self.screen_image.maximum, 100) self.min_slider_edit.setValidator(validator) self.max_slider_edit.setValidator(validator) else: validator.setRange(self.screen_image.minimum, self.screen_image.maximum, 100) self.slider_edit.validator().setTop(self.slider.maximum()) if self.label.width() < 256 or self.label.height() < 256: self.label.resize(256, 256) def append(self, images): if self.images is None: self.images = images else: if images.ndim == 2: images = images[np.newaxis, :, :] if images.shape[1:] != self.images.shape[1:]: raise ImageViewingError('Appended images have different shape ' f'{images.shape[1:]} than the displayed ones ' f'{self.images.shape[1:]}') self.images = np.concatenate((self.images, images)) def set_enabled_adjustments(self, enabled): self.slider.setEnabled(enabled) self.slider_edit.setEnabled(enabled) self.min_slider.setEnabled(enabled) self.min_slider_edit.setEnabled(enabled) self.max_slider.setEnabled(enabled) self.max_slider_edit.setEnabled(enabled) def reset_clim(self, auto=False): self.screen_image.reset() if auto: self.screen_image.auto_levels() self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point)) self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point)) self.set_slider_value(self.min_slider, self.screen_image.black_point) self.set_slider_value(self.max_slider, self.screen_image.white_point) self.label.updateImage() self._update_pg_window_lut() @property def popup_visible(self): return self._pg_window and self._pg_window.isVisible() def popup(self): import pyqtgraph pyqtgraph.setConfigOptions(antialias=True, imageAxisOrder='row-major') if self._pg_window is not None: if not self._pg_window.isVisible(): self._pg_window.show() return def on_pg_window_time_changed(index, time): self._set_index(index) self.slider.blockSignals(True) self.slider_edit.setText(str(index)) self.slider.setValue(index) self.slider.blockSignals(False) def on_pg_window_levels_changed(hist_item): minimum, maximum = hist_item.getLevels() if (self.screen_image.minimum <= minimum <= self.screen_image.maximum and self.screen_image.minimum <= maximum <= self.screen_image.maximum): self.min_slider_edit.setText('{:g}'.format(minimum)) self.set_slider_value(self.min_slider, minimum) self.max_slider_edit.setText('{:g}'.format(maximum)) self.set_slider_value(self.max_slider, maximum) self.screen_image.black_point = minimum self.screen_image.white_point = maximum self.label.updateImage() def pg_mouse_moved(ev): if self._pg_window.imageItem.sceneBoundingRect().contains(ev): pos = self._pg_window.imageItem.mapFromScene(ev) x = int(pos.x() + 0.5) y = int(pos.y() + 0.5) self._pg_window.view.setTitle('x={}, y={}, I={:g}'.format(x, y, self._pg_window.imageItem.image[y, x])) else: self._pg_window.view.setTitle('') self._pg_window = pyqtgraph.ImageView(view=pyqtgraph.PlotItem()) self._pg_window.imageItem.scene().sigMouseMoved.connect(pg_mouse_moved) self._pg_window.setWindowFlag(Qt.SubWindow, True) self._update_pg_window_images() self._update_pg_window_index() self._update_pg_window_lut() self._pg_window.show() self._pg_window.sigTimeChanged.connect(on_pg_window_time_changed) self._pg_window.ui.histogram.item.sigLevelsChanged.connect(on_pg_window_levels_changed) def cleanup(self): if self._pg_window: self._pg_window.close() self._pg_window = None def _set_index(self, index): self.screen_image.image = self.images[index] self.label.updateImage() def _update_pg_window_images(self): if self.images.shape[0] == 1: im_to_set = self.images[0] else: im_to_set = self.images self._pg_window.setImage(im_to_set, autoLevels=False) def _update_pg_window_index(self): if self._images.shape[0] > 1 and self._pg_window is not None: self._pg_window.blockSignals(True) self._pg_window.setCurrentIndex(self.slider.value()) self._pg_window.blockSignals(False) def _update_pg_window_lut(self): if self._pg_window is not None: self._pg_window.ui.histogram.item.blockSignals(True) self._pg_window.setLevels(self.screen_image.black_point, self.screen_image.white_point) self._pg_window.ui.histogram.item.blockSignals(False) def on_slider_value_changed(self, value): self._set_index(value) self.slider_edit.setText(str(value)) self._update_pg_window_index() def on_slider_edit_return_pressed(self): self.slider.setValue(int(self.slider_edit.text())) def on_min_slider_edit_return_pressed(self): value = float(self.min_slider_edit.text()) if value < self.screen_image.white_point: self.screen_image.black_point = value self.set_slider_value(self.min_slider, value) self.label.updateImage() self._update_pg_window_lut() def on_max_slider_edit_return_pressed(self): value = float(self.max_slider_edit.text()) if value > self.screen_image.black_point: self.screen_image.white_point = value self.set_slider_value(self.max_slider, value) self.label.updateImage() self._update_pg_window_lut() def on_min_slider_value_changed(self, value): self.screen_image.set_black_point_normalized(value) self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point)) self.label.updateImage() self._update_pg_window_lut() def on_max_slider_value_changed(self, value): self.screen_image.set_white_point_normalized(value) self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point)) self.label.updateImage() self._update_pg_window_lut() def set_slider_value(self, slider, value): slider.blockSignals(True) slider.setValue(int(self.screen_image.convert_native_value_to_normalized(value))) slider.blockSignals(False) class ImageViewingError(FlowError): pass tofu-0.12.0/tofu/genreco.py000066400000000000000000000733561423713721100155370ustar00rootroot00000000000000"""General projection-based reconstruction for tomographic/laminographic cone/parallel beam data sets. """ import copy import itertools import logging import os import time import numpy as np from multiprocessing.pool import ThreadPool from gi.repository import Ufo from .preprocess import create_preprocessing_pipeline from .util import (get_filtering_padding, get_reconstructed_cube_shape, get_reconstruction_regions, get_filenames, determine_shape, get_scarray_value, Vector) from .tasks import get_task, get_writer LOG = logging.getLogger(__name__) DTYPE_CL_SIZE = {'float': 4, 'double': 8, 'half': 2, 'uchar': 1, 'ushort': 2, 'uint': 4} def genreco(args): st = time.time() if is_output_single_file(args): try: import ufo.numpy except ImportError: LOG.error('You must install ufo-python-tools to be able to write single-file output') return if (args.energy is not None and args.propagation_distance is not None and not (args.projection_margin or args.disable_projection_crop)): LOG.warning('Phase retrieval without --projection-margin specification or ' '--disable-projection-crop may cause convolution artifacts') _fill_missing_args(args) _convert_angles_to_rad(args) set_projection_filter_scale(args) x_region, y_region, z_region = get_reconstruction_regions(args, store=True, dtype=float) vol_shape = get_reconstructed_cube_shape(x_region, y_region, z_region) bpp = DTYPE_CL_SIZE[args.store_type] num_voxels = vol_shape[0] * vol_shape[1] * vol_shape[2] vol_nbytes = num_voxels * bpp resources = [Ufo.Resources()] gpus = np.array(resources[0].get_gpu_nodes()) gpu_indices = np.array(args.gpus or list(range(len(gpus)))) if min(gpu_indices) < 0 or max(gpu_indices) > len(gpus) - 1: raise ValueError('--gpus contains invalid indices') gpus = gpus[gpu_indices] duration = 0 for i, gpu in enumerate(gpus): print('Max mem for {}: {:.2f} GB'.format(i, gpu.get_info(0) / 2. ** 30)) runs = make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp, slices_per_device=args.slices_per_device, slice_memory_coeff=args.slice_memory_coeff, data_splitting_policy=args.data_splitting_policy, num_gpu_threads=args.num_gpu_threads) for i in range(len(runs[0]) - 1): resources.append(Ufo.Resources()) LOG.info('Number of passes: %d', len(runs)) LOG.debug('GPUs and regions:') for regions in runs: LOG.debug('%s', str(regions)) for i, regions in enumerate(runs): duration += _run(resources, args, x_region, y_region, regions, i, vol_nbytes) num_gupdates = num_voxels * args.number * 1e-9 total_duration = time.time() - st LOG.debug('UFO duration: %.2f s', duration) LOG.debug('Total duration: %.2f s', total_duration) LOG.debug('UFO performance: %.2f GUPS', num_gupdates / duration) LOG.debug('Total performance: %.2f GUPS', num_gupdates / total_duration) def make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp, slices_per_device=None, slice_memory_coeff=0.8, data_splitting_policy='one', num_gpu_threads=1): gpu_indices = np.array(gpu_indices) def _add_region(runs, gpu_index, current, to_process, z_start, z_step): current_per_thread = current // num_gpu_threads for i in range(num_gpu_threads): if i + 1 == num_gpu_threads: current_per_thread += current % num_gpu_threads z_end = z_start + current_per_thread * z_step runs[-1].append((gpu_indices[gpu_index], [z_start, z_end, z_step])) z_start = z_end return z_start, z_end, to_process - current z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region slice_width, slice_height, num_slices = get_reconstructed_cube_shape(x_region, y_region, z_region) if slices_per_device: slices_per_device = [slices_per_device for i in range(len(gpus))] else: slices_per_device = get_num_slices_per_gpu(gpus, slice_width, slice_height, bpp, slice_memory_coeff=slice_memory_coeff) max_slices_per_pass = sum(slices_per_device) if not max_slices_per_pass: raise RuntimeError('None of the available devices has enough memory to store any slices') num_full_passes = num_slices // max_slices_per_pass LOG.debug('Number of slices: %d', num_slices) LOG.debug('Slices per device %s', slices_per_device) LOG.debug('Maximum slices on all GPUs per pass: %d', max_slices_per_pass) LOG.debug('Number of passes with full workload: %d', num_slices // max_slices_per_pass) sorted_indices = np.argsort(slices_per_device)[-np.count_nonzero(slices_per_device):] runs = [] z_start = z_region[0] to_process = num_slices # Create passes where all GPUs are fully loaded for j in range(num_full_passes): runs.append([]) for i in sorted_indices: z_start, z_end, to_process = _add_region(runs, i, slices_per_device[i], to_process, z_start, z_step) if to_process: if data_splitting_policy == 'one': # Fill the last pass by maximizing the workload per GPU runs.append([]) for i in sorted_indices[::-1]: if not to_process: break current = min(slices_per_device[i], to_process) z_start, z_end, to_process = _add_region(runs, i, current, to_process, z_start, z_step) else: # Fill the last pass by maximizing the number of GPUs which will work num_gpus = len(sorted_indices) runs.append([]) for j, i in enumerate(sorted_indices): # Current GPU will either process the maximum number of slices it can. If the number # of slices per GPU based on even division between them cannot saturate the GPU, use # this number. This way the work will be split evenly between the GPUs. current = max(min(slices_per_device[i], (to_process - 1) // (num_gpus - j) + 1), 1) z_start, z_end, to_process = _add_region(runs, i, current, to_process, z_start, z_step) if not to_process: break return runs def get_num_slices_per_gpu(gpus, width, height, bpp, slice_memory_coeff=0.8): num_slices = [] slice_size = width * height * bpp for i, gpu in enumerate(gpus): max_mem = gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) num_slices.append(int(np.floor(max_mem * slice_memory_coeff / slice_size))) return num_slices def _run(resources, args, x_region, y_region, regions, run_number, vol_nbytes): """Execute one pass on all possible GPUs with slice ranges given by *regions*. Use separate thread per GPU and optimize the read projection regions. """ executors = [] for index in range(len(regions)): gpu_index, region = regions[index] region_index = run_number * len(resources) + index executors.append(Executor(resources[index], args, region, x_region, y_region, gpu_index, region_index)) def start_one(index): return executors[index].process() st = time.time() pool = ThreadPool(processes=len(regions)) result = pool.map_async(start_one, list(range(len(regions)))) if is_output_single_file(args): import tifffile bigtiff = vol_nbytes > 2 ** 32 - 1 LOG.debug('Writing BigTiff: %s', bigtiff) dirname = os.path.dirname(args.output) if dirname and not os.path.exists(dirname): os.makedirs(dirname) with tifffile.TiffWriter(args.output, append=run_number != 0, bigtiff=bigtiff) as writer: for executor in executors: executor.consume(writer) result.wait() return time.time() - st def setup_graph(args, graph, x_region, y_region, region, source=None, gpu=None, do_output=True, index=0, make_reader=True): backproject = get_task('general-backproject', processing_node=gpu) if do_output: if args.dry_run: sink = get_task('null', processing_node=gpu, download=True) else: sink = get_writer(args) sink.props.filename = '{}-{:>03}-%04i.tif'.format(args.output, index) width = args.width height = args.height if args.transpose_input: tmp = width width = height height = tmp if args.projection_filter != 'none' and args.projection_crop_after == 'backprojection': # Take projection padding into account padding = get_filtering_padding(width) args.center_position_x = [pos + padding / 2 for pos in args.center_position_x] if args.z_parameter == 'center-position-x': region = [region[0] + padding / 2, region[1] + padding / 2, region[2]] LOG.debug('center-position-x after padding: %g', args.center_position_x[0]) backproject.props.parameter = args.z_parameter if args.burst: backproject.props.burst = args.burst backproject.props.z = args.z backproject.props.region = region backproject.props.x_region = x_region backproject.props.y_region = y_region backproject.props.center_position_x = args.center_position_x backproject.props.center_position_z = args.center_position_z backproject.props.source_position_x = args.source_position_x backproject.props.source_position_y = args.source_position_y backproject.props.source_position_z = args.source_position_z backproject.props.detector_position_x = args.detector_position_x backproject.props.detector_position_y = args.detector_position_y backproject.props.detector_position_z = args.detector_position_z backproject.props.detector_angle_x = args.detector_angle_x backproject.props.detector_angle_y = args.detector_angle_y backproject.props.detector_angle_z = args.detector_angle_z backproject.props.axis_angle_x = args.axis_angle_x backproject.props.axis_angle_y = args.axis_angle_y backproject.props.axis_angle_z = args.axis_angle_z backproject.props.volume_angle_x = args.volume_angle_x backproject.props.volume_angle_y = args.volume_angle_y backproject.props.volume_angle_z = args.volume_angle_z backproject.props.num_projections = args.number backproject.props.compute_type = args.compute_type backproject.props.result_type = args.result_type backproject.props.store_type = args.store_type backproject.props.overall_angle = args.overall_angle backproject.props.addressing_mode = args.genreco_padding_mode backproject.props.gray_map_min = args.slice_gray_map[0] backproject.props.gray_map_max = args.slice_gray_map[1] source = create_preprocessing_pipeline(args, graph, source=source, processing_node=gpu, cone_beam_weight=not args.disable_cone_beam_weight, make_reader=make_reader) if source: graph.connect_nodes(source, backproject) else: source = backproject if do_output: graph.connect_nodes(backproject, sink) last = sink else: last = backproject return (source, last) def is_output_single_file(args): filename = args.output.lower() return not args.dry_run and (filename.endswith('.tif') or filename.endswith('.tiff')) def set_projection_filter_scale(args): is_parallel = np.all(np.isinf(args.source_position_y)) magnification = (args.source_position_y[0] - args.detector_position_y[0]) / \ args.source_position_y[0] args.projection_filter_scale = 1. if is_parallel: if np.any(np.array(args.axis_angle_x)): LOG.debug('Adjusting filter for parallel beam laminography') args.projection_filter_scale = 0.5 * np.cos(args.axis_angle_x[0]) else: args.projection_filter_scale = 0.5 args.projection_filter_scale /= magnification ** 2 if np.all(np.array(args.axis_angle_x) == 0): LOG.debug('Adjusting filter for cone beam tomography') args.projection_filter_scale /= magnification def _fill_missing_args(args): (width, height) = determine_shape(args, args.projections, store=False) if args.transpose_input: tmp = width width = height height = tmp args.center_position_x = (args.center_position_x or [width / 2.]) args.center_position_z = (args.center_position_z or [height / 2.]) if not args.overall_angle: args.overall_angle = 360. LOG.info('Overall angle not specified, using 360 deg') if not args.number: if len(args.axis_angle_z) > 1: LOG.debug("--number not specified, using length of --axis-angle-z: %d", len(args.axis_angle_z)) args.number = len(args.axis_angle_z) else: num_files = len(get_filenames(args.projections)) if not num_files: raise RuntimeError("No files found in `{}'".format(args.projections)) LOG.debug("--number not specified, using number of files matching " "--projections pattern: %d", num_files) args.number = num_files if args.dry_run: if not args.number: raise ValueError('--number must be specified by --dry-run') determine_shape(args, args.projections, store=True) LOG.info('Dummy data W x H x N: {} x {} x {}'.format(args.width, args.height, args.number)) return args def _convert_angles_to_rad(args): names = ['detector_angle', 'axis_angle', 'volume_angle'] coords = ['x', 'y', 'z'] angular_z_params = [x[0].replace('_', '-') + '-' + x[1] for x in itertools.product(names, coords)] args.overall_angle = np.deg2rad(args.overall_angle) if args.z_parameter in angular_z_params: LOG.debug('Converting z parameter values to radians') args.region = _convert_list_to_rad(args.region) for name in names: for coord in coords: full_name = name + '_' + coord values = getattr(args, full_name) setattr(args, full_name, _convert_list_to_rad(values)) def _convert_list_to_rad(values): return np.deg2rad(np.array(values)).tolist() def _are_values_equal(values): return np.all(np.array(values) == values[0]) class Executor(object): def __init__(self, resources, args, region, x_region, y_region, gpu_index, region_index): self.resources = resources self.args = args self.region = region self.gpu_index = gpu_index self.x_region = x_region self.y_region = y_region self.region_index = region_index self.single_file_output = is_output_single_file(self.args) self.output = Ufo.OutputTask() if self.single_file_output else None def process(self): scheduler = Ufo.FixedScheduler() scheduler.set_resources(self.resources) graph = Ufo.TaskGraph() gpu = scheduler.get_resources().get_gpu_nodes()[self.gpu_index] geometry = CTGeometry(self.args) if (len(self.args.center_position_z) == 1 and np.modf(self.args.center_position_z[0])[0] == 0 and geometry.is_simple_parallel_tomo): LOG.info('Simple tomography with integer z center, changing to center_position_z + 0.5 ' 'to avoid interpolation') geometry.args.center_position_z = (geometry.args.center_position_z[0] + 0.5,) if not self.args.disable_projection_crop: if not self.args.dry_run and (self.args.y or self.args.height or self.args.transpose_input): LOG.debug('--y or --height or --transpose-input specified, ' 'not optimizing projection region') else: geometry.optimize_args(region=self.region) opt_args = geometry.args if self.args.dry_run: source = get_task('dummy-data', number=self.args.number, width=self.args.width, height=self.args.height) else: source = None last = setup_graph(opt_args, graph, self.x_region, self.y_region, self.region, source=source, gpu=gpu, index=self.region_index, make_reader=True, do_output=not self.single_file_output)[-1] if self.single_file_output: graph.connect_nodes(last, self.output) LOG.debug('Device: %d, region: %s', self.gpu_index, self.region) scheduler.run(graph) return scheduler.props.time def consume(self, writer): import ufo.numpy for i in np.arange(*self.region): buf = self.output.get_output_buffer() writer.save(ufo.numpy.asarray(buf)) self.output.release_output_buffer(buf) class CTGeometry(object): def __init__(self, args): self.args = copy.deepcopy(args) determine_shape(self.args, self.args.projections, store=True) get_reconstruction_regions(self.args, store=True, dtype=float) self.args.center_position_x = (self.args.center_position_x or [self.args.width / 2.]) self.args.center_position_z = (self.args.center_position_z or [self.args.height / 2.]) @property def is_parallel(self): return np.all(np.isinf(self.args.source_position_y)) @property def is_detector_rotated(self): return (np.any(self.args.detector_angle_x) or np.any(self.args.detector_angle_y) or np.any(self.args.detector_angle_z)) @property def is_axis_rotated(self): return (np.any(self.args.axis_angle_x) or np.any(self.args.axis_angle_y) or np.any(self.args.axis_angle_z)) @property def is_volume_rotated(self): return (np.any(self.args.volume_angle_x) or np.any(self.args.volume_angle_y) or np.any(self.args.volume_angle_z)) @property def is_center_position_x_constant(self): return _are_values_equal(self.args.center_position_x) @property def is_center_position_z_constant(self): return _are_values_equal(self.args.center_position_z) @property def is_center_constant(self): return self.is_center_position_x_constant and self.is_center_position_z_constant @property def is_simple_parallel_tomo(self): return (not (self.is_axis_rotated or self.is_detector_rotated or self.is_volume_rotated) and self.is_parallel and self.is_center_constant) def optimize_args(self, region=None): xmin, ymin, xmax, ymax = self.compute_height(region=region) center_position_z = np.array(self.args.center_position_z) - ymin self.args.center_position_z = center_position_z.tolist() self.args.y = int(ymin) self.args.height = int(ymax - ymin) LOG.debug('Optimization for region: %s', region or self.args.region) LOG.debug('Optimized X: %d - %d, Z: %d - %d', xmin, xmax, ymin, ymax) LOG.debug('Optimized Z: %d', self.args.y) LOG.debug('Optimized height: %d', self.args.height) LOG.debug('Optimized center_position_z: %g - %g', self.args.center_position_z[0], self.args.center_position_z[-1]) def compute_height(self, region=None): extrema = [] if not region: region = self.args.region if self.is_simple_parallel_tomo: # Simple parallel beam tomography, thus compute only the horizontal crop at rotations # which are multiples of 45 degrees LOG.debug('Computing optimal projection region from 4 angles') projs_per_45 = self.args.number / self.args.overall_angle * np.pi / 4 stop = 4 if self.args.overall_angle <= np.pi else 8 indices = projs_per_45 * np.arange(1, stop, 2) indices = np.round(indices).astype(np.int).tolist() else: LOG.debug('Computing optimal projection region from all angles') indices = list(range(self.args.number)) for i in indices: extrema_0 = self._compute_one_parameter(region[0], i) extrema_1 = self._compute_one_parameter(region[1], i) extrema.append(extrema_0) extrema.append(extrema_1) minima = np.min(extrema, axis=0) maxima = np.max(extrema, axis=0) if maxima[-1] == minima[2]: # Don't let height be 0 maxima[-1] += 1 result = tuple(minima[::2]) + tuple(maxima[1::2]) return result def _compute_one_parameter(self, param_value, index): source_position = np.array([get_scarray_value(self.args.source_position_x, index), get_scarray_value(self.args.source_position_y, index), get_scarray_value(self.args.source_position_z, index)]) axis = Vector(x_angle=get_scarray_value(self.args.axis_angle_x, index), y_angle=get_scarray_value(self.args.axis_angle_y, index), z_angle=get_scarray_value(self.args.axis_angle_z, index), position=[get_scarray_value(self.args.center_position_x, index), 0, get_scarray_value(self.args.center_position_z, index)]) detector = Vector(x_angle=get_scarray_value(self.args.detector_angle_x, index), y_angle=get_scarray_value(self.args.detector_angle_y, index), z_angle=get_scarray_value(self.args.detector_angle_z, index), position=[get_scarray_value(self.args.detector_position_x, index), get_scarray_value(self.args.detector_position_y, index), get_scarray_value(self.args.detector_position_z, index)]) volume_angle = Vector(x_angle=get_scarray_value(self.args.volume_angle_x, index), y_angle=get_scarray_value(self.args.volume_angle_y, index), z_angle=get_scarray_value(self.args.volume_angle_z, index)) z = self.args.z if self.args.z_parameter == 'z': z = param_value elif self.args.z_parameter == 'axis-angle-x': axis.x_angle = param_value elif self.args.z_parameter == 'axis-angle-y': axis.y_angle = param_value elif self.args.z_parameter == 'axis-angle-z': axis.z_angle = param_value elif self.args.z_parameter == 'volume-angle-x': volume_angle.x_angle = param_value elif self.args.z_parameter == 'volume-angle-y': volume_angle.y_angle = param_value elif self.args.z_parameter == 'volume-angle-z': volume_angle.z_angle = param_value elif self.args.z_parameter == 'detector-angle-x': detector.x_angle = param_value elif self.args.z_parameter == 'detector-angle-y': detector.y_angle = param_value elif self.args.z_parameter == 'detector-angle-z': detector.z_angle = param_value elif self.args.z_parameter == 'detector-position-x': detector.position[0] = param_value elif self.args.z_parameter == 'detector-position-y': detector.position[1] = param_value elif self.args.z_parameter == 'detector-position-z': detector.position[2] = param_value elif self.args.z_parameter == 'source-position-x': source_position[0] = param_value elif self.args.z_parameter == 'source-position-y': source_position[1] = param_value elif self.args.z_parameter == 'source-position-z': source_position[2] = param_value elif self.args.z_parameter == 'center-position-x': axis.position[0] = param_value elif self.args.z_parameter == 'center-position-z': axis.position[2] = param_value else: raise RuntimeError("Unknown z parameter '{}'".format(self.args.z_parameter)) points = get_extrema(self.args.x_region, self.args.y_region, z) if self.args.z_parameter != 'z': points_upper = get_extrema(self.args.x_region, self.args.y_region, z + 1) points = np.hstack((points, points_upper)) tomo_angle = float(index) / self.args.number * self.args.overall_angle xe, ye = compute_detector_pixels(points, source_position, axis, volume_angle, detector, tomo_angle) return compute_detector_region(xe, ye, (self.args.height, self.args.width), overhead=self.args.projection_margin) def project(points, source, detector_normal, detector_offset): """Project *points* onto a detector.""" x, y, z = points source_extended = np.tile(source[:, np.newaxis], [1, points.shape[1]]) detector_normal_extended = np.tile(detector_normal[:, np.newaxis], [1, points.shape[1]]) denom = np.sum((points - source_extended) * detector_normal_extended, axis=0) if np.isinf(source[1]): # Parallel beam if np.any(detector_normal != np.array([0., -1, 0])): # Detector is not perpendicular, compute translation along the beam direction, # otherwise don't compute anything because voxels are mapped directly # to detector coordinates points[1, :] = - (detector_offset + detector_normal[0] * points[0, :] + detector_normal[2] * points[2, :]) / detector_normal[1] projected = points else: # Cone beam u = -(detector_offset + np.dot(source, detector_normal)) / denom u = np.tile(u, [3, 1]) projected = source_extended + (points - source_extended) * u return projected def compute_detector_pixels(points, source_position, axis, volume_rotation, detector, tomo_angle): """*points* are a list of points along x-direcion, thus the array has height 3. *source_position* is a 3-vector, *axis*, *volume_rotation* and *detector* are util.Vector instances. """ # Rotate the axis detector_normal = np.array((0, -1, 0), dtype=np.float) detector_normal = rotate_z(detector.z_angle, detector_normal) detector_normal = rotate_y(detector.y_angle, detector_normal) detector_normal = rotate_x(detector.x_angle, detector_normal) # Compute d from ax + by + cz + d = 0 detector_offset = -np.dot(detector.position, detector_normal) if np.isinf(source_position[1]): # Parallel beam voxels = points else: # Apply magnification voxels = -points * source_position[1] / (detector.position[1] - source_position[1]) # Rotate the volume voxels = rotate_z(volume_rotation.z_angle, voxels) voxels = rotate_y(volume_rotation.y_angle, voxels) voxels = rotate_x(volume_rotation.x_angle, voxels) # Rotate around the axis voxels = rotate_z(tomo_angle, voxels) # Rotate the volume voxels = rotate_z(axis.z_angle, voxels) voxels = rotate_y(axis.y_angle, voxels) voxels = rotate_x(axis.x_angle, voxels) # Get the projected pixel projected = project(voxels, source_position, detector_normal, detector_offset) if np.any(detector_normal != np.array([0., -1, 0])): # Detector is not perpendicular projected -= np.array([detector.position]).T # Reverse rotation => reverse order of transformation matrices and negative angles projected = rotate_x(-detector.x_angle, projected) projected = rotate_y(-detector.y_angle, projected) projected = rotate_z(-detector.z_angle, projected) x = projected[0, :] + axis.position[0] - 0.5 y = projected[2, :] + axis.position[2] - 0.5 return x, y def compute_detector_region(x, y, shape, overhead=2): """*overhead* specifies how much margin is taken into account around the computed area.""" def _compute_outlier(extremum_func, values): if extremum_func == min: round_func = np.floor sgn = -1 else: round_func = np.ceil sgn = +1 return int(round_func(extremum_func(values)) + sgn * overhead) x_min = min(shape[1], max(0, _compute_outlier(min, x))) y_min = min(shape[0], max(0, _compute_outlier(min, y))) x_max = max(0, min(shape[1], _compute_outlier(max, x))) y_max = max(0, min(shape[0], _compute_outlier(max, y))) return (x_min, x_max, y_min, y_max) def get_extrema(x_region, y_region, z): def get_extrema(region): return (region[0], region[1]) product = itertools.product(get_extrema(x_region), get_extrema(y_region), [z]) return np.array(list(product), dtype=np.float).T.copy() def rotate_x(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[1, 1] = cos matrix[1, 2] = -sin matrix[2, 1] = sin matrix[2, 2] = cos return np.dot(matrix, point) def rotate_y(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[0, 0] = cos matrix[0, 2] = sin matrix[2, 0] = -sin matrix[2, 2] = cos return np.dot(matrix, point) def rotate_z(angle, point): cos = np.cos(angle) sin = np.sin(angle) matrix = np.identity(3) matrix[0, 0] = cos matrix[0, 1] = -sin matrix[1, 0] = sin matrix[1, 1] = cos return np.dot(matrix, point) tofu-0.12.0/tofu/gui.py000066400000000000000000000545711423713721100146770ustar00rootroot00000000000000import sys import os import logging import numpy as np import tifffile import pkg_resources from argparse import ArgumentParser from contextlib import contextmanager from . import reco, config, util, __version__ try: import tofu.vis.qt from PyQt4 import QtGui, QtCore, uic except ImportError: raise ImportError("Cannot import modules for GUI, please install PyQt4 and pyqtgraph") LOG = logging.getLogger(__name__) def set_last_dir(path, line_edit, last_dir): if os.path.exists(str(path)): line_edit.clear() line_edit.setText(path) last_dir = str(line_edit.text()) return last_dir def get_filtered_filenames(path, exts=['.tif', '.edf']): result = [] try: for ext in exts: result += [os.path.join(path, f) for f in os.listdir(path) if f.endswith(ext)] except OSError: return [] return sorted(result) @contextmanager def spinning_cursor(): QtGui.QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.WaitCursor)) yield QtGui.QApplication.restoreOverrideCursor() class CallableHandler(logging.Handler): def __init__(self, func): logging.Handler.__init__(self) self.func = func def emit(self, record): self.func(self.format(record)) class ApplicationWindow(QtGui.QMainWindow): def __init__(self, app, params): QtGui.QMainWindow.__init__(self) self.params = params self.app = app ui_file = pkg_resources.resource_filename(__name__, 'gui.ui') self.ui = uic.loadUi(ui_file, self) self.ui.show() self.ui.setAttribute(QtCore.Qt.WA_DeleteOnClose) self.ui.tab_widget.setCurrentIndex(0) self.ui.slice_dock.setVisible(False) self.ui.volume_dock.setVisible(False) self.ui.axis_view_widget.setVisible(False) self.slice_viewer = None self.volume_viewer = None self.overlap_viewer = tofu.vis.qt.OverlapViewer() self.get_values_from_params() log_handler = CallableHandler(self.on_log_record) log_handler.setLevel(logging.DEBUG) log_handler.setFormatter(logging.Formatter('%(name)s: %(message)s')) root_logger = logging.getLogger('') root_logger.setLevel(logging.DEBUG) root_logger.handlers = [log_handler] self.ui.input_path_button.setToolTip('Path to projections or sinograms') self.ui.proj_button.setToolTip('Denote if path contains projections') self.ui.y_step.setToolTip(self.get_help('reading', 'y-step')) self.ui.method_box.setToolTip(self.get_help('tomographic-reconstruction', 'method')) self.ui.axis_spin.setToolTip(self.get_help('tomographic-reconstruction', 'axis')) self.ui.angle_step.setToolTip(self.get_help('reconstruction', 'angle')) self.ui.angle_offset.setToolTip(self.get_help('tomographic-reconstruction', 'offset')) self.ui.oversampling.setToolTip(self.get_help('dfi', 'oversampling')) self.ui.iterations_sart.setToolTip(self.get_help('ir', 'num-iterations')) self.ui.relaxation.setToolTip(self.get_help('sart', 'relaxation-factor')) self.ui.output_path_button.setToolTip(self.get_help('general', 'output')) self.ui.ffc_box.setToolTip(self.get_help('gui', 'ffc-correction')) self.ui.interpolate_button.setToolTip('Interpolate between two sets of flat fields') self.ui.darks_path_button.setToolTip(self.get_help('flat-correction', 'darks')) self.ui.flats_path_button.setToolTip(self.get_help('flat-correction', 'flats')) self.ui.flats2_path_button.setToolTip(self.get_help('flat-correction', 'flats2')) self.ui.path_button_0.setToolTip(self.get_help('gui', 'deg0')) self.ui.path_button_180.setToolTip(self.get_help('gui', 'deg180')) self.ui.input_path_button.clicked.connect(self.on_input_path_clicked) self.ui.sino_button.clicked.connect(self.on_sino_button_clicked) self.ui.proj_button.clicked.connect(self.on_proj_button_clicked) self.ui.region_box.clicked.connect(self.on_region_box_clicked) self.ui.method_box.currentIndexChanged.connect(self.change_method) self.ui.axis_spin.valueChanged.connect(self.change_axis_spin) self.ui.angle_step.valueChanged.connect(self.change_angle_step) self.ui.output_path_button.clicked.connect(self.on_output_path_clicked) self.ui.ffc_box.clicked.connect(self.on_ffc_box_clicked) self.ui.interpolate_button.clicked.connect(self.on_interpolate_button_clicked) self.ui.darks_path_button.clicked.connect(self.on_darks_path_clicked) self.ui.flats_path_button.clicked.connect(self.on_flats_path_clicked) self.ui.flats2_path_button.clicked.connect(self.on_flats2_path_clicked) self.ui.ffc_options.currentIndexChanged.connect(self.change_ffc_options) self.ui.reco_button.clicked.connect(self.on_reconstruct) self.ui.path_button_0.clicked.connect(self.on_path_0_clicked) self.ui.path_button_180.clicked.connect(self.on_path_180_clicked) self.ui.show_slices_button.clicked.connect(self.on_show_slices_clicked) self.ui.show_volume_button.clicked.connect(self.on_show_volume_clicked) self.ui.run_button.clicked.connect(self.on_compute_center) self.ui.save_action.triggered.connect(self.on_save_as) self.ui.clear_action.triggered.connect(self.on_clear) self.ui.clear_output_dir_action.triggered.connect(self.on_clear_output_dir_clicked) self.ui.open_action.triggered.connect(self.on_open_from) self.ui.close_action.triggered.connect(self.close) self.ui.about_action.triggered.connect(self.on_about) self.ui.extrema_checkbox.clicked.connect(self.on_remove_extrema_clicked) self.ui.overlap_opt.currentIndexChanged.connect(self.on_overlap_opt_changed) self.ui.input_path_line.textChanged.connect(self.on_input_path_changed) self.ui.y_step.valueChanged.connect(lambda value: self.change_value('y_step', value)) self.ui.angle_offset.valueChanged.connect(lambda value: self.change_value('offset', value)) self.ui.oversampling.valueChanged.connect(lambda value: self.change_value('oversampling', value)) self.ui.iterations_sart.valueChanged.connect(lambda value: self.change_value('num_iterations', value)) self.ui.relaxation.valueChanged.connect(lambda value: self.change_value('relaxation_factor', value)) self.ui.output_path_line.textChanged.connect(lambda value: self.change_value('output', str(self.ui.output_path_line.text()))) self.ui.darks_path_line.textChanged.connect(lambda value: self.change_value('darks', str(self.ui.darks_path_line.text()))) self.ui.flats_path_line.textChanged.connect(lambda value: self.change_value('flats', str(self.ui.flats_path_line.text()))) self.ui.flats2_path_line.textChanged.connect(lambda value: self.change_value('flats2', str(self.ui.flats2_path_line.text()))) self.ui.fix_naninf_box.clicked.connect(lambda value: self.change_value('fix_nan_and_inf', self.ui.fix_naninf_box.isChecked())) self.ui.absorptivity_box.clicked.connect(lambda value: self.change_value('absorptivity', self.ui.absorptivity_box.isChecked())) self.ui.path_line_0.textChanged.connect(lambda value: self.change_value('deg0', str(self.ui.path_line_0.text()))) self.ui.path_line_180.textChanged.connect(lambda value: self.change_value('deg180', str(self.ui.path_line_180.text()))) self.ui.overlap_layout.addWidget(self.overlap_viewer) self.overlap_viewer.slider.valueChanged.connect(self.on_axis_slider_changed) def on_log_record(self, record): self.ui.text_browser.append(record) def get_values_from_params(self): self.ui.input_path_line.setText(self.params.sinograms or self.params.projections or '.') self.ui.output_path_line.setText(self.params.output or '') self.ui.darks_path_line.setText(self.params.darks or '') self.ui.flats_path_line.setText(self.params.flats or '') self.ui.flats2_path_line.setText(self.params.flats2 or '') self.ui.path_line_0.setText(self.params.deg0) self.ui.path_line_180.setText(self.params.deg180) self.ui.y_step.setValue(self.params.y_step if self.params.y_step else 1) self.ui.axis_spin.setValue(self.params.axis if self.params.axis else 0.0) self.ui.angle_step.setValue(self.params.angle if self.params.angle else 0.0) self.ui.angle_offset.setValue(self.params.offset if self.params.offset else 0.0) self.ui.oversampling.setValue(self.params.oversampling if self.params.oversampling else 0) self.ui.iterations_sart.setValue(self.params.num_iterations if self.params.num_iterations else 0) self.ui.relaxation.setValue(self.params.relaxation_factor if self.params.relaxation_factor else 0.0) if self.params.projections is not None: self.ui.proj_button.setChecked(True) self.ui.sino_button.setChecked(False) self.on_proj_button_clicked() else: self.ui.proj_button.setChecked(False) self.ui.sino_button.setChecked(True) self.on_sino_button_clicked() if self.params.method == "fbp": self.ui.method_box.setCurrentIndex(0) elif self.params.method == "dfi": self.ui.method_box.setCurrentIndex(1) elif self.params.method == "sart": self.ui.method_box.setCurrentIndex(2) self.change_method() if self.params.y_step > 1 and self.sino_button.isChecked(): self.ui.region_box.setChecked(True) else: self.ui.region_box.setChecked(False) self.ui.on_region_box_clicked() ffc_enabled = bool(self.params.flats) and bool(self.params.darks) and self.proj_button.isChecked() self.ui.ffc_box.setChecked(ffc_enabled) self.ui.preprocessing_container.setVisible(ffc_enabled) self.ui.interpolate_button.setChecked(bool(self.params.flats2) and ffc_enabled) self.ui.fix_naninf_box.setChecked(self.params.fix_nan_and_inf) self.ui.absorptivity_box.setChecked(self.params.absorptivity) if self.params.reduction_mode.lower() == "average": self.ui.ffc_options.setCurrentIndex(0) else: self.ui.ffc_options.setCurrentIndex(1) def change_method(self): self.params.method = str(self.ui.method_box.currentText()).lower() is_dfi = self.params.method == 'dfi' is_sart = self.params.method == 'sart' for w in (self.ui.oversampling_label, self.ui.oversampling): w.setVisible(is_dfi) for w in (self.ui.relaxation, self.ui.relaxation_label, self.ui.iterations_sart, self.ui.iterations_sart_label): w.setVisible(is_sart) def get_help(self, section, name): help = config.SECTIONS[section][name]['help'] return help def change_value(self, name, value): setattr(self.params, name, value) def on_sino_button_clicked(self): self.on_input_path_changed() self.ui.ffc_box.setEnabled(False) self.ui.preprocessing_container.setVisible(False) def on_proj_button_clicked(self): self.on_input_path_changed() self.ui.ffc_box.setEnabled(True) self.ui.preprocessing_container.setVisible(self.ffc_box.isChecked()) self.ui.region_box.setEnabled(False) self.ui.region_box.setChecked(False) self.on_region_box_clicked() def on_region_box_clicked(self): self.ui.y_step.setEnabled(self.ui.region_box.isChecked()) if self.ui.region_box.isChecked(): self.params.y_step = self.ui.y_step.value() else: self.params.y_step = 1 def on_input_path_changed(self): if self.ui.sino_button.isChecked(): self.params.sinograms = str(self.ui.input_path_line.text()) self.params.projections = None else: self.params.sinograms = None self.params.projections = str(self.ui.input_path_line.text()) def on_input_path_clicked(self, checked): directory = self.params.projections or self.params.sinograms path = self.get_path(directory, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.input_path_line, self.params.last_dir) def change_axis_spin(self): if self.ui.axis_spin.value() == 0: self.params.axis = None else: self.params.axis = self.ui.axis_spin.value() def change_angle_step(self): if self.ui.angle_step.value() == 0: self.params.angle = None else: self.params.angle = self.ui.angle_step.value() def on_output_path_clicked(self, checked): path = self.get_path(self.params.output, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.output_path_line, self.params.last_dir) def on_clear_output_dir_clicked(self): with spinning_cursor(): output_absfiles = get_filtered_filenames(str(self.ui.output_path_line.text())) for f in output_absfiles: os.remove(f) def on_ffc_box_clicked(self): checked = self.ui.ffc_box.isChecked() self.ui.preprocessing_container.setVisible(checked) self.params.ffc_correction = checked def on_interpolate_button_clicked(self): checked = self.ui.interpolate_button.isChecked() self.ui.flats2_path_line.setEnabled(checked) self.ui.flats2_path_button.setEnabled(checked) def change_ffc_options(self): self.params.reduction_mode = str(self.ui.ffc_options.currentText()).lower() def on_darks_path_clicked(self, checked): path = self.get_path(self.params.darks, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.darks_path_line, self.params.last_dir) def on_flats_path_clicked(self, checked): path = self.get_path(self.params.flats, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.flats_path_line, self.params.last_dir) def on_flats2_path_clicked(self, checked): path = self.get_path(self.params.flats2, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.flats2_path_line, self.params.last_dir) def get_path(self, directory, last_dir): return QtGui.QFileDialog.getExistingDirectory(self, '.', last_dir or directory) def get_filename(self, directory, last_dir): return QtGui.QFileDialog.getOpenFileName(self, '.', last_dir or directory) def on_path_0_clicked(self, checked): path = self.get_filename(self.params.deg0, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.path_line_0, self.params.last_dir) def on_path_180_clicked(self, checked): path = self.get_filename(self.params.deg180, self.params.last_dir) self.params.last_dir = set_last_dir(path, self.ui.path_line_180, self.params.last_dir) def on_open_from(self): config_file = QtGui.QFileDialog.getOpenFileName(self, 'Open ...', self.params.last_dir) parser = ArgumentParser() params = config.Params(sections=config.TOMO_PARAMS + ('gui',)) parser = params.add_arguments(parser) self.params = parser.parse_known_args(config.config_to_list(config_name=config_file))[0] self.get_values_from_params() def on_about(self): message = "GUI is part of ufo-reconstruct {}.".format(__version__) QtGui.QMessageBox.about(self, "About ufo-reconstruct", message) def on_save_as(self): if os.path.exists(self.params.last_dir): config_file = str(self.params.last_dir + "/reco.conf") else: config_file = str(os.getenv('HOME') + "reco.conf") save_config = QtGui.QFileDialog.getSaveFileName(self, 'Save as ...', config_file) if save_config: sections = config.TOMO_PARAMS + ('gui',) config.write(save_config, args=self.params, sections=sections) def on_clear(self): self.ui.axis_view_widget.setVisible(False) self.ui.input_path_line.setText('.') self.ui.output_path_line.setText('.') self.ui.darks_path_line.setText('.') self.ui.flats_path_line.setText('.') self.ui.flats2_path_line.setText('.') self.ui.path_line_0.setText('.') self.ui.path_line_180.setText('.') self.ui.fix_naninf_box.setChecked(True) self.ui.absorptivity_box.setChecked(True) self.ui.sino_button.setChecked(True) self.ui.proj_button.setChecked(False) self.ui.region_box.setChecked(False) self.ui.ffc_box.setChecked(False) self.ui.interpolate_button.setChecked(False) self.ui.y_step.setValue(1) self.ui.axis_spin.setValue(0) self.ui.angle_step.setValue(0) self.ui.angle_offset.setValue(0) self.ui.oversampling.setValue(0) self.ui.ffc_options.setCurrentIndex(0) self.ui.text_browser.clear() self.ui.method_box.setCurrentIndex(0) self.params.enable_cropping = False self.params.reduction_mode = "average" self.params.fix_nan_and_inf = True self.params.absorptivity = True self.params.show_2d = False self.params.show_3d = False self.params.angle = None self.params.axis = None self.on_region_box_clicked() self.on_ffc_box_clicked() self.on_interpolate_button_clicked() def on_reconstruct(self): with spinning_cursor(): self.ui.centralWidget.setEnabled(False) self.repaint() self.app.processEvents() input_images = get_filtered_filenames(str(self.ui.input_path_line.text())) if not input_images: self.gui_warn("No data found in {}".format(str(self.ui.input_path_line.text()))) self.ui.centralWidget.setEnabled(True) return shape = util.get_image_shape(input_images[0]) self.params.width = shape[-1] self.params.height = shape[-2] self.params.ffc_correction = self.params.ffc_correction and self.ui.proj_button.isChecked() if not (self.params.output.endswith('.tif') or self.params.output.endswith('.tiff')): self.params.output = os.path.join(self.params.output, 'slice-%05i.tif') if self.params.y_step > 1: self.params.angle *= self.params.y_step if self.params.ffc_correction: flats_files = get_filtered_filenames(str(self.ui.flats_path_line.text())) self.params.num_flats = len(flats_files) else: self.params.num_flats = 0 self.params.darks = None self.params.flats = None self.params.flats2 = self.ui.flats2_path_line.text() if self.ui.interpolate_button.isChecked() else '' self.params.oversampling = self.ui.oversampling.value() if self.params.method == 'dfi' else None if self.params.method == 'sart': self.params.max_iterations = self.ui.iterations_sart.value() self.params.relaxation_factor = self.ui.relaxation.value() if self.params.angle is None: self.gui_warn("Missing argument for Angle step (rad)") else: try: reco.tomo(self.params) except Exception as e: self.gui_warn(str(e)) self.ui.centralWidget.setEnabled(True) self.params.angle = self.ui.angle_step.value() def on_show_slices_clicked(self): path = str(self.ui.output_path_line.text()) filenames = get_filtered_filenames(path) if not self.slice_viewer: self.slice_viewer = tofu.vis.qt.ImageViewer(filenames) self.slice_dock.setWidget(self.slice_viewer) self.ui.slice_dock.setVisible(True) else: self.slice_viewer.load_files(filenames) def on_show_volume_clicked(self): if not self.volume_viewer: step = int(self.ui.reduction_box.currentText()) self.volume_viewer = tofu.vis.qt.VolumeViewer(parent=self, step=step) self.volume_dock.setWidget(self.volume_viewer) self.ui.volume_dock.setVisible(True) path = str(self.ui.output_path_line.text()) filenames = get_filtered_filenames(path) self.volume_viewer.load_files(filenames) def on_compute_center(self): first_name = str(self.ui.path_line_0.text()) second_name = str(self.ui.path_line_180.text()) with tifffile.TiffFile(first_name) as tif: first = tif.pages[0].asarray().astype(np.float) with tifffile.TiffFile(second_name) as tif: second = tif.pages[-1].asarray().astype(np.float) if self.params.ffc_correction: # FIXME: we should of course use the pipelines we have ... flat_files = get_filtered_filenames(str(self.ui.flats_path_line.text())) dark_files = get_filtered_filenames(str(self.ui.darks_path_line.text())) flats = np.array([tifffile.TiffFile(x).asarray().astype(np.float) for x in flat_files]) darks = np.array([tifffile.TiffFile(x).asarray().astype(np.float) for x in dark_files]) dark = np.mean(darks, axis=0) flat = np.mean(flats, axis=0) - dark first = (first - dark) / flat second = (second - dark) / flat self.axis = reco.compute_rotation_axis(first, second) self.height, self.width = first.shape w2 = self.width / 2.0 position = w2 + (w2 - self.axis) * 2.0 self.overlap_viewer.set_images(first, second) self.overlap_viewer.set_position(position) self.ui.img_size.setText('width = {} | height = {}'.format(self.width, self.height)) def on_remove_extrema_clicked(self, val): self.ui.overlap_viewer.remove_extrema = val def on_overlap_opt_changed(self, index): self.ui.overlap_viewer.subtract = index == 0 self.ui.overlap_viewer.update_image() def on_axis_slider_changed(self): val = self.overlap_viewer.slider.value() w2 = self.width / 2.0 self.axis = w2 + (w2 - val) / 2 self.ui.axis_num.setText('{} px'.format(self.axis)) self.ui.axis_spin.setValue(self.axis) def gui_warn(self, message): QtGui.QMessageBox.warning(self, "Warning", message) def main(params): app = QtGui.QApplication(sys.argv) ApplicationWindow(app, params) sys.exit(app.exec_()) tofu-0.12.0/tofu/gui.ui000066400000000000000000001377761423713721100146750ustar00rootroot00000000000000 mainWindow 0 0 1018 1081 0 0 Tomoviewer true 0 0 541 761 0 0 530 0 PreferDefault true Qt::TabFocus Qt::LeftToRight 1 true 0 0 0 0 Reconstruction Input 0 0 Projections true 0 0 Sinograms true true false 0 0 1 500 0 0 false false Region (y-step): Qt::Horizontal 40 20 0 0 Do flat-field correction 0 0 Path: 0 0 0 0 Browse … 0 0 Flat-field correction 0 0 Average Median Options Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Use absorptivity Remove NaN and Inf 0 0 Interpolate Qt::Horizontal 40 20 Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true Darks: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true true Browse … true 0 Flats: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Last flats: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter true true Browse … Browse … Reconstruction 6 0 0 Angle step (rad): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 50 false Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 50 false FBP DFI SART 0 0 10 Qt::Horizontal 40 20 0 0 Angle offset (rad): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Reconstruct 0 0 Axis (pixel): Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 8192.000000000000000 0 0 10 0 0 Max iterations: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 0 Relaxation factor: 0.000000000000000 Qt::Horizontal 40 20 true 99 0 true 0 0 Oversampling: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Output 0 0 Path: 0 0 0 0 Browse ... Reduction: 1 1 2 4 8 Qt::Horizontal 40 20 Show Volume Show Slices 0 0 Log 0 0 QFrame::Sunken 0 0 0 Center of rotation 0 0 Input Options: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter Method: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Browse ... 0 0 180° projection: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Browse ... 0 0 In case of multi-page input, last image in the file is used Remove extrema 0 0 In case of multi-page input, first image in the file is used 0 0 0° projection: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Subtraction overlap Addition overlap 0 0 Run Output 0 0 Center: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 Size: Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter 0 0 0 0 75 true 0 0 0 0 0 0 0 0 0 1018 22 0 0 0 0 File 0 0 Edit Help QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable 2 0 0 QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable 2 Open Save as... Open ... Qt::WindowShortcut Save as ... Quit Ctrl+Q Clear Remove old slices Remove old slices in output directory About tofu-0.12.0/tofu/lamino.py000066400000000000000000000212451423713721100153620ustar00rootroot00000000000000"""Laminographic reconstruction.""" import logging import numpy as np from multiprocessing import Queue, Process from tofu.preprocess import create_preprocessing_pipeline from tofu.util import (get_filtering_padding, determine_shape, get_filenames, get_reconstruction_regions, get_reconstructed_cube_shape) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def lamino(params): """Laminographic reconstruction utilizing all GPUs.""" LOG.info('Z parameter: {}'.format(params.z_parameter)) prepare_angular_arguments(params) params.projection_filter_scale = np.sin(np.deg2rad(params.lamino_angle)) # For now we need to make a workaround for the memory leak, which means we need to execute # the passes in separate processes to clean up the low level code. For that we also need to # call the region-splitting in a separate function. # TODO: Simplify after the memory leak fix! queue = Queue() proc = Process(target=_create_runs, args=(params, queue,)) proc.start() proc.join() x_region, y_region, regions, num_gpus = queue.get() for i in range(0, len(regions), num_gpus): z_subregion = regions[i:min(i + num_gpus, len(regions))] LOG.info('Computing slices {}..{}'.format(z_subregion[0][0], z_subregion[-1][1])) proc = Process(target=_run, args=(params, x_region, y_region, z_subregion, i // num_gpus)) proc.start() proc.join() def prepare_angular_arguments(params): if not params.overall_angle: params.overall_angle = 360. LOG.info('Overall angle not specified, using 360 deg') if not params.angle: if params.dry_run: if not params.number: raise ValueError('--number must be specified by --dry-run') num_files = params.number else: num_files = len(get_filenames(params.projections)) if not num_files: raise RuntimeError("No files found in `{}'".format(params.projections)) params.angle = params.overall_angle / num_files * params.step LOG.info('Angle not specified, calculating from ' + '{} projections and step {}: {} deg'.format(num_files, params.step, params.angle)) determine_shape(params, params.projections, store=True) if not params.number: params.number = int(np.round(np.abs(params.overall_angle / params.angle))) if params.dry_run: LOG.info('Dummy data W x H x N: {} x {} x {}'.format(params.width, params.height, params.number)) def _create_runs(params, queue): """Workaround function to get the number of gpus and compute regions. gi.repository must always be called in a separate process, otherwise the resources return None gpus. """ #TODO: remove the whole function after memory leak fix! from gi.repository import Ufo scheduler = Ufo.FixedScheduler() gpus = scheduler.get_resources().get_gpu_nodes() num_gpus = len(gpus) x_region, y_region, regions = _split_regions(params, gpus) LOG.info('Using {} GPUs in {} passes'.format(min(len(regions), num_gpus), len(regions))) queue.put((x_region, y_region, regions, num_gpus)) def _run(params, x_region, y_region, regions, index): """Execute one pass on all possible GPUs with slice ranges given by *regions*.""" from gi.repository import Ufo pm = Ufo.PluginManager() graph = Ufo.TaskGraph() scheduler = Ufo.FixedScheduler() gpus = scheduler.get_resources().get_gpu_nodes() num_gpus = len(gpus) broadcast = Ufo.CopyTask() source = _setup_source(params, pm, graph) graph.connect_nodes(source, broadcast) for i, region in enumerate(regions): subindex = index * num_gpus + i _setup_graph(pm, graph, subindex, x_region, y_region, region, params, broadcast, gpu=gpus[i]) scheduler.run(graph) duration = scheduler.props.time LOG.info('Execution time: {} s'.format(duration)) return duration def _setup_source(params, pm, graph): from tofu.preprocess import create_flat_correct_pipeline from tofu.util import set_node_props, setup_read_task if params.dry_run: source = pm.get_task('dummy-data') source.props.number = params.number source.props.width = params.width source.props.height = params.height elif params.darks and params.flats: source = create_flat_correct_pipeline(params, graph) else: source = pm.get_task('read') set_node_props(source, params) setup_read_task(source, params.projections, params) return source def _setup_graph(pm, graph, index, x_region, y_region, region, params, source, gpu=None): backproject = get_task('lamino-backproject', processing_node=gpu) slicer = get_task('slice', processing_node=gpu) writer = get_writer(params) if not params.dry_run: writer.props.filename = '{}-{:>03}-%04i.tif'.format(params.output, index) # parameters backproject.props.num_projections = params.number backproject.props.overall_angle = np.deg2rad(params.overall_angle) backproject.props.lamino_angle = np.deg2rad(params.lamino_angle) backproject.props.roll_angle = np.deg2rad(params.roll_angle) backproject.props.x_region = x_region backproject.props.y_region = y_region backproject.props.z = params.z backproject.props.addressing_mode = params.lamino_padding_mode backproject.props.parameter = params.z_parameter if params.projection_crop_after == 'backprojection': padding = get_filtering_padding(params.width) else: padding = 0 if params.z_parameter in ['lamino-angle', 'roll-angle']: region = [np.deg2rad(reg) for reg in region] if params.z_parameter == 'x-center': # Take projection padding into account region = [region[0] + padding / 2, region[1] + padding / 2, region[2]] backproject.props.region = region backproject.props.center = (params.axis[0] + padding / 2, params.axis[1]) LOG.debug('x center after padding: %g', backproject.props.center[0]) graph.connect_nodes(backproject, slicer) graph.connect_nodes(slicer, writer) if params.only_bp: first = backproject graph.connect_nodes(source, backproject) else: first = create_preprocessing_pipeline(params, graph, source=source, processing_node=gpu) graph.connect_nodes(first, backproject) return first def _split_regions(params, gpus): """Split processing between *gpus* by specifying the number of slices processed per GPU.""" x_region, y_region, z_region = get_reconstruction_regions(params) z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region slice_width, slice_height, num_slices = get_reconstructed_cube_shape(x_region, y_region, z_region) if params.slices_per_device: num_slices_per_gpu = params.slices_per_device else: num_slices_per_gpu = _compute_num_slices(gpus, slice_width, slice_height) if num_slices_per_gpu > num_slices: num_slices_per_gpu = num_slices LOG.info('Using {} slices per GPU'.format(num_slices_per_gpu)) z_starts = np.arange(z_start, z_stop, z_step * num_slices_per_gpu) regions = [] for start in z_starts: regions.append((start, min(z_stop, start + z_step * num_slices_per_gpu), z_step)) return x_region, y_region, regions def _compute_num_slices(gpus, width, height): """Determine number of slices which can be calculated per-device based on *gpus*, slice *width* and *height*. """ from gi.repository import Ufo # Make sure the double buffering works with room for intermediate steps # TODO: compute this precisely safety_coeff = 3. # Use the weakest one, if heterogenous systems emerge, measure the performance and # reconsider memories = [gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) for gpu in gpus] i = np.argmin(memories) max_allocatable = gpus[i].get_info(Ufo.GpuNodeInfo.MAX_MEM_ALLOC_SIZE) if max_allocatable * safety_coeff <= memories[i]: # Don't waste resources max_memory = max_allocatable else: max_memory = memories[i] / safety_coeff if max_memory > 2 ** 32: # Current NVIDIA implementation allows only 4 GB max_memory = 2 ** 32 max_memory /= safety_coeff num_slices = int(np.floor(max_memory / (width * height * 4))) LOG.info('GPU memory used per GPU: {:.2f} GB'.format(max_memory / 2. ** 30)) return num_slices tofu-0.12.0/tofu/preprocess.py000066400000000000000000000357121423713721100162740ustar00rootroot00000000000000"""Flat field correction.""" import sys import logging from gi.repository import Ufo from tofu.util import (get_filenames, set_node_props, make_subargs, determine_shape, setup_read_task, setup_padding, next_power_of_two) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) def create_flat_correct_pipeline(args, graph, processing_node=None): """ Create flat field correction pipeline. All the settings are provided in *args*. *graph* is used for making the connections. Returns the flat field correction task which can be used for further pipelining. """ pm = Ufo.PluginManager() if args.projections is None or args.flats is None or args.darks is None: raise RuntimeError("You must specify --projections, --flats and --darks.") reader = get_task('read') dark_reader = get_task('read') flat_before_reader = get_task('read') ffc = get_task('flat-field-correct', processing_node=processing_node, dark_scale=args.dark_scale, flat_scale=args.flat_scale, absorption_correct=args.absorptivity, fix_nan_and_inf=args.fix_nan_and_inf) mode = args.reduction_mode.lower() roi_args = make_subargs(args, ['y', 'height', 'y_step']) set_node_props(reader, args) set_node_props(dark_reader, roi_args) set_node_props(flat_before_reader, roi_args) for r, path in ((reader, args.projections), (dark_reader, args.darks), (flat_before_reader, args.flats)): setup_read_task(r, path, args) LOG.debug("Doing flat field correction using reduction mode `{}'".format(mode)) if args.flats2: flat_after_reader = get_task('read') setup_read_task(flat_after_reader, args.flats2, args) set_node_props(flat_after_reader, roi_args) num_files = len(get_filenames(args.projections)) can_read = len(list(range(args.start, num_files, args.step))) number = args.number if args.number else num_files num_read = min(can_read, number) flat_interpolate = get_task('interpolate', processing_node=processing_node, number=num_read) if args.resize: LOG.debug("Resize input data by factor of {}".format(args.resize)) proj_bin = get_task('bin', processing_node=processing_node, size=args.resize) dark_bin = get_task('bin', processing_node=processing_node, size=args.resize) flat_bin = get_task('bin', processing_node=processing_node, size=args.resize) graph.connect_nodes(reader, proj_bin) graph.connect_nodes(dark_reader, dark_bin) graph.connect_nodes(flat_before_reader, flat_bin) reader, dark_reader, flat_before_reader = proj_bin, dark_bin, flat_bin if args.flats2: flat_bin = get_task('bin', processing_node=processing_node, size=args.resize) graph.connect_nodes(flat_after_reader, flat_bin) flat_after_reader = flat_bin if mode == 'median': dark_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.darks))) dark_reduced = get_task('flatten', processing_node=processing_node, mode='median') flat_before_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.flats))) flat_before_reduced = get_task('flatten', processing_node=processing_node, mode='median') graph.connect_nodes(dark_reader, dark_stack) graph.connect_nodes(dark_stack, dark_reduced) graph.connect_nodes(flat_before_reader, flat_before_stack) graph.connect_nodes(flat_before_stack, flat_before_reduced) if args.flats2: flat_after_stack = get_task('stack', processing_node=processing_node, number=len(get_filenames(args.flats2))) flat_after_reduced = get_task('flatten', processing_node=processing_node, mode='median') graph.connect_nodes(flat_after_reader, flat_after_stack) graph.connect_nodes(flat_after_stack, flat_after_reduced) elif mode == 'average': dark_reduced = get_task('average', processing_node=processing_node) flat_before_reduced = get_task('average', processing_node=processing_node) graph.connect_nodes(dark_reader, dark_reduced) graph.connect_nodes(flat_before_reader, flat_before_reduced) if args.flats2: flat_after_reduced = get_task('average', processing_node=processing_node) graph.connect_nodes(flat_after_reader, flat_after_reduced) else: raise ValueError('Invalid reduction mode') graph.connect_nodes_full(reader, ffc, 0) graph.connect_nodes_full(dark_reduced, ffc, 1) if args.flats2: graph.connect_nodes_full(flat_before_reduced, flat_interpolate, 0) graph.connect_nodes_full(flat_after_reduced, flat_interpolate, 1) graph.connect_nodes_full(flat_interpolate, ffc, 2) else: graph.connect_nodes_full(flat_before_reduced, ffc, 2) return ffc def create_phase_retrieval_pipeline(args, graph, processing_node=None): LOG.debug('Creating phase retrieval pipeline') pm = Ufo.PluginManager() # Retrieve phase phase_retrieve = get_task('retrieve-phase', processing_node=processing_node) pad_phase_retrieve = get_task('pad', processing_node=processing_node) crop_phase_retrieve = get_task('crop', processing_node=processing_node) fft_phase_retrieve = get_task('fft', processing_node=processing_node) ifft_phase_retrieve = get_task('ifft', processing_node=processing_node) last = crop_phase_retrieve width = args.width height = args.height default_padded_width = next_power_of_two(width + 64) default_padded_height = next_power_of_two(height + 64) if not args.retrieval_padded_width: args.retrieval_padded_width = default_padded_width if not args.retrieval_padded_height: args.retrieval_padded_height = default_padded_height fmt = 'Phase retrieval padding: {}x{} -> {}x{}' LOG.debug(fmt.format(width, height, args.retrieval_padded_width, args.retrieval_padded_height)) x = (args.retrieval_padded_width - width) // 2 y = (args.retrieval_padded_height - height) // 2 pad_phase_retrieve.props.x = x pad_phase_retrieve.props.y = y pad_phase_retrieve.props.width = args.retrieval_padded_width pad_phase_retrieve.props.height = args.retrieval_padded_height pad_phase_retrieve.props.addressing_mode = args.retrieval_padding_mode crop_phase_retrieve.props.x = x crop_phase_retrieve.props.y = y crop_phase_retrieve.props.width = width crop_phase_retrieve.props.height = height phase_retrieve.props.method = args.retrieval_method phase_retrieve.props.energy = args.energy if len(args.propagation_distance) == 1: phase_retrieve.props.distance = [args.propagation_distance[0]] else: phase_retrieve.props.distance_x = args.propagation_distance[0] phase_retrieve.props.distance_y = args.propagation_distance[1] phase_retrieve.props.pixel_size = args.pixel_size phase_retrieve.props.regularization_rate = args.regularization_rate phase_retrieve.props.thresholding_rate = args.thresholding_rate phase_retrieve.props.frequency_cutoff = args.frequency_cutoff fft_phase_retrieve.props.dimensions = 2 ifft_phase_retrieve.props.dimensions = 2 graph.connect_nodes(pad_phase_retrieve, fft_phase_retrieve) graph.connect_nodes(fft_phase_retrieve, phase_retrieve) graph.connect_nodes(phase_retrieve, ifft_phase_retrieve) graph.connect_nodes(ifft_phase_retrieve, crop_phase_retrieve) calculate = get_task('calculate', processing_node=processing_node) if args.delta is not None: import numpy as np lam = 6.62606896e-34 * 299792458 / (args.energy * 1.60217733e-16) thickness_conversion = -lam / (2 * np.pi * args.delta) else: thickness_conversion = 1 if args.retrieval_method == 'tie': expression = '(isinf (v) || isnan (v) || (v <= 0)) ? 0.0f : -log ({} * v) * {{}}' # 2 for 0.5 factor in ufo-filters and alpha = 10^-R, so divide by 10^R expression = expression.format(2 / 10 ** args.regularization_rate) # The following converts the TIE result to the actual phase, which when multiplied by the # thickness_conversion gives the projected thickness thickness_conversion *= -10 ** args.regularization_rate / 2 expression = expression.format(thickness_conversion) else: expression = '(isinf (v) || isnan (v)) ? 0.0f : v * {}'.format(thickness_conversion) calculate.props.expression = expression graph.connect_nodes(crop_phase_retrieve, calculate) last = calculate return (pad_phase_retrieve, last) def run_flat_correct(args): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() pm = Ufo.PluginManager() out_task = get_writer(args) flat_task = create_flat_correct_pipeline(args, graph) graph.connect_nodes(flat_task, out_task) sched.run(graph) def create_sinogram_pipeline(args, graph): """Create sinogram generating pipeline based on arguments from *args*.""" pm = Ufo.PluginManager() sinos = pm.get_task('transpose-projections') if args.number: region = (args.start, args.start + args.number, args.step) num_projections = len(list(range(*region))) else: num_projections = len(get_filenames(args.projections)) sinos.props.number = num_projections if args.darks and args.flats: start = create_flat_correct_pipeline(args, graph) else: start = get_task('read') start.props.path = args.projections set_node_props(start, args) graph.connect_nodes(start, sinos) return sinos def run_sinogram_generation(args): """Make the sinograms with arguments provided by *args*.""" if not args.height: args.height = determine_shape(args, args.projections)[1] - args.y step = args.y_step * args.pass_size if args.pass_size else args.height starts = list(range(args.y, args.y + args.height, step)) + [args.y + args.height] def generate_partial(append=False): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() args.output_append = append writer = get_writer(args) sinos = create_sinogram_pipeline(args, graph) graph.connect_nodes(sinos, writer) sched.run(graph) for i in range(len(starts) - 1): args.y = starts[i] args.height = starts[i + 1] - starts[i] generate_partial(append=i != 0) def create_projection_filtering_pipeline(args, graph, processing_node=None): pm = Ufo.PluginManager() pad = get_task('pad', processing_node=processing_node) fft = get_task('fft', processing_node=processing_node) ifft = get_task('ifft', processing_node=processing_node) fltr = get_task('filter', processing_node=processing_node) if args.projection_crop_after == 'filter': crop = get_task('crop', processing_node=processing_node) else: crop = None padding_width = setup_padding(pad, args.width, args.height, args.projection_padding_mode, crop=crop)[0] fft.props.dimensions = 1 ifft.props.dimensions = 1 fltr.props.filter = args.projection_filter fltr.props.scale = args.projection_filter_scale fltr.props.cutoff = args.projection_filter_cutoff graph.connect_nodes(pad, fft) graph.connect_nodes(fft, fltr) graph.connect_nodes(fltr, ifft) if crop: graph.connect_nodes(ifft, crop) last = crop else: last = ifft return (pad, last) def create_preprocessing_pipeline(args, graph, source=None, processing_node=None, cone_beam_weight=True, make_reader=True): """If *make_reader* is True, create a read task if *source* is None and no dark and flat fields are given. """ import numpy as np if not (args.width and args.height): width, height = determine_shape(args, args.projections) if not width: raise RuntimeError("Could not determine width from the input") if not args.width: args.width = width if not args.height: args.height = height - args.y LOG.debug('Image width x height: %d x %d', args.width, args.height) current = None if source: current = source elif args.darks and args.flats: current = create_flat_correct_pipeline(args, graph, processing_node=processing_node) else: if make_reader: current = get_task('read') set_node_props(current, args) if not args.projections: raise RuntimeError('--projections not set') setup_read_task(current, args.projections, args) if args.absorptivity: absorptivity = get_task('calculate', processing_node=processing_node) absorptivity.props.expression = 'v <= 0 ? 0.0f : -log(v)' if current: graph.connect_nodes(current, absorptivity) current = absorptivity if args.transpose_input: transpose = get_task('transpose') if current: graph.connect_nodes(current, transpose) current = transpose tmp = args.width args.width = args.height args.height = tmp if cone_beam_weight and not np.all(np.isinf(args.source_position_y)): # Cone beam projection weight LOG.debug('Enabling cone beam weighting') weight = get_task('cone-beam-projection-weight', processing_node=processing_node) weight.props.source_distance = (-np.array(args.source_position_y)).tolist() weight.props.detector_distance = args.detector_position_y weight.props.center_position_x = args.center_position_x or [args.width / 2. + (args.width % 2) * 0.5] weight.props.center_position_z = args.center_position_z or [args.height / 2. + (args.height % 2) * 0.5] weight.props.axis_angle_x = args.axis_angle_x if current: graph.connect_nodes(current, weight) current = weight if args.energy is not None and args.propagation_distance is not None: pr_first, pr_last = create_phase_retrieval_pipeline(args, graph, processing_node=processing_node) if current: graph.connect_nodes(current, pr_first) current = pr_last if args.projection_filter != 'none': pf_first, pf_last = create_projection_filtering_pipeline(args, graph, processing_node=processing_node) if current: graph.connect_nodes(current, pf_first) current = pf_last return current def run_preprocessing(args): graph = Ufo.TaskGraph() sched = Ufo.Scheduler() pm = Ufo.PluginManager() out_task = get_writer(args) current = create_preprocessing_pipeline(args, graph) graph.connect_nodes(current, out_task) sched.run(graph) tofu-0.12.0/tofu/reco.py000066400000000000000000000267411423713721100150410ustar00rootroot00000000000000import os import logging import glob import tempfile import sys import numpy as np from gi.repository import Ufo from tofu.preprocess import create_flat_correct_pipeline from tofu.util import (set_node_props, setup_read_task, get_filenames, read_image, determine_shape, setup_padding) from tofu.tasks import get_task, get_writer LOG = logging.getLogger(__name__) pm = Ufo.PluginManager() def get_dummy_reader(params): if params.width is None and params.height is None: raise RuntimeError("You have to specify --width and --height when generating data.") width, height = params.width, params.height reader = get_task('dummy-data', width=width, height=height, number=params.number or 1) return reader, width, height def get_file_reader(params): reader = pm.get_task('read') set_node_props(reader, params) return reader def get_projection_reader(params): reader = get_file_reader(params) setup_read_task(reader, params.projections, params) width, height = determine_shape(params, params.projections) return reader, width, height def get_sinogram_reader(params): reader = get_file_reader(params) setup_read_task(reader, params.sinograms, params) width, height = determine_shape(params, path=params.sinograms) return reader, width, height def tomo(params): # Create reader and writer if params.projections and params.sinograms: raise RuntimeError("Cannot specify both --projections and --sinograms.") if params.projections is None and params.sinograms is None: reader, width, height = get_dummy_reader(params) else: if params.projections: reader, width, height = get_projection_reader(params) else: reader, width, height = get_sinogram_reader(params) axis = params.axis or width / 2.0 if params.projections and params.resize: width /= params.resize height /= params.resize axis /= params.resize LOG.debug("Input dimensions: {}x{} pixels".format(width, height)) writer = get_writer(params) # Setup graph depending on the chosen method and input data g = Ufo.TaskGraph() if params.projections is not None: if params.number: count = len(list(range(params.start, params.start + params.number, params.step))) else: count = len(get_filenames(params.projections)) LOG.debug("Number of projections: {}".format(count)) sino_output = get_task('transpose-projections', number=count) if params.darks and params.flats: g.connect_nodes(create_flat_correct_pipeline(params, g), sino_output) else: g.connect_nodes(reader, sino_output) if height: # Sinogram height is the one needed for further padding height = count else: sino_output = reader if params.method == 'fbp': fft = get_task('fft', dimensions=1) ifft = get_task('ifft', dimensions=1) fltr = get_task('filter', filter=params.projection_filter, cutoff=params.projection_filter_cutoff) bp = get_task('backproject', axis_pos=axis) last_node = bp if params.angle: bp.props.angle_step = params.angle if params.offset: bp.props.angle_offset = params.offset if width and height: # Pad the image with its extent to prevent reconstuction ring pad = get_task('pad') crop = get_task('crop') if params.projection_crop_after == 'filter': crop_after_filter = crop else: crop_after_filter = None padding_width = setup_padding(pad, width, height, params.projection_padding_mode, crop=crop_after_filter)[0] LOG.debug("Padding input to: {}x{} pixels".format(pad.props.width, pad.props.height)) g.connect_nodes(sino_output, pad) g.connect_nodes(pad, fft) g.connect_nodes(fft, fltr) g.connect_nodes(fltr, ifft) if crop_after_filter: g.connect_nodes(ifft, crop) g.connect_nodes(crop, bp) else: bp.props.axis_pos = axis + padding_width / 2 crop.props.x = padding_width // 2 crop.props.y = padding_width // 2 crop.props.width = width crop.props.height = width g.connect_nodes(ifft, bp) g.connect_nodes(bp, crop) last_node = crop else: if params.crop_width: ifft.props.crop_width = int(params.crop_width) LOG.debug("Cropping to {} pixels".format(ifft.props.crop_width)) g.connect_nodes(sino_output, fft) g.connect_nodes(fft, fltr) g.connect_nodes(fltr, ifft) g.connect_nodes(ifft, bp) g.connect_nodes(last_node, writer) if params.method in ('sart', 'sirt', 'sbtv', 'asdpocs'): projector = pm.get_task_from_package('ir', 'parallel-projector') projector.set_properties(model='joseph', is_forward=False) projector.set_properties(axis_position=axis) projector.set_properties(step=params.angle if params.angle else np.pi / 180.0) method = pm.get_task_from_package('ir', params.method) method.set_properties(projector=projector, num_iterations=params.num_iterations) if params.method in ('sart', 'sirt'): method.set_properties(relaxation_factor=params.relaxation_factor) if params.method == 'asdpocs': minimizer = pm.get_task_from_package('ir', 'sirt') method.set_properties(df_minimizer=minimizer) if params.method == 'sbtv': # FIXME: the lambda keyword is preventing from the following # assignment ... # method.props.lambda = params.lambda method.set_properties(mu=params.mu) g.connect_nodes(sino_output, method) g.connect_nodes(method, writer) if params.method == 'dfi': oversampling = params.oversampling or 1 pad = get_task('zeropad', center_of_rotation=axis, oversampling=oversampling) fft = get_task('fft', dimensions=1, auto_zeropadding=0) dfi = get_task('dfi-sinc') ifft = get_task('ifft', dimensions=2) swap_forward = get_task('swap-quadrants') swap_backward = get_task('swap-quadrants') if params.angle: dfi.props.angle_step = params.angle g.connect_nodes(sino_output, pad) g.connect_nodes(pad, fft) g.connect_nodes(fft, dfi) g.connect_nodes(dfi, swap_forward) g.connect_nodes(swap_forward, ifft) g.connect_nodes(ifft, swap_backward) if width: crop = get_task('crop') crop.set_properties(from_center=True, width=width, height=width) g.connect_nodes(swap_backward, crop) g.connect_nodes(crop, writer) else: g.connect_nodes(swap_backward, writer) scheduler = Ufo.Scheduler() if hasattr(scheduler.props, 'enable_tracing'): LOG.debug("Use tracing: {}".format(params.enable_tracing)) scheduler.props.enable_tracing = params.enable_tracing scheduler.run(g) duration = scheduler.props.time LOG.info("Execution time: {} s".format(duration)) return duration def estimate_center(params): if params.estimate_method == 'reconstruction': axis = estimate_center_by_reconstruction(params) else: axis = estimate_center_by_correlation(params) return axis def estimate_center_by_reconstruction(params): if params.projections is not None: raise RuntimeError("Cannot estimate axis from projections") sinos = sorted(glob.glob(os.path.join(params.sinograms, '*.tif'))) if not sinos: raise RuntimeError("No sinograms found in {}".format(params.sinograms)) # Use a sinogram that probably has some interesting data filename = sinos[len(sinos) // 2] sinogram = read_image(filename) initial_width = sinogram.shape[1] m0 = np.mean(np.sum(sinogram, axis=1)) center = initial_width / 2.0 width = initial_width / 2.0 new_center = center tmp_dir = tempfile.mkdtemp() tmp_output = os.path.join(tmp_dir, 'slice-0.tif') params.sinograms = filename params.output = os.path.join(tmp_dir, 'slice-%i.tif') def heaviside(A): return (A >= 0.0) * 1.0 def get_score(guess, m0): # Run reconstruction with new guess params.axis = guess tomo(params) # Analyse reconstructed slice result = read_image(tmp_output) Q_IA = float(np.sum(np.abs(result)) / m0) Q_IN = float(-np.sum(result * heaviside(-result)) / m0) LOG.info("Q_IA={}, Q_IN={}".format(Q_IA, Q_IN)) return Q_IA def best_center(center, width): trials = [center + (width / 4.0) * x for x in range(-2, 3)] scores = [(guess, get_score(guess, m0)) for guess in trials] LOG.info(scores) best = sorted(scores, cmp=lambda x, y: cmp(x[1], y[1])) return best[0][0] for i in range(params.num_iterations): LOG.info("Estimate iteration: {}".format(i)) new_center = best_center(new_center, width) LOG.info("Currently best center: {}".format(new_center)) width /= 2.0 try: os.remove(tmp_output) os.removedirs(tmp_dir) except OSError: LOG.info("Could not remove {} or {}".format(tmp_output, tmp_dir)) return new_center def estimate_center_by_correlation(params): """Use correlation to estimate center of rotation for tomography.""" def flat_correct(flat, radio): nonzero = np.where(radio != 0) result = np.zeros_like(radio) result[nonzero] = flat[nonzero] / radio[nonzero] # log(1) = 0 result[result <= 0] = 1 return np.log(result) first = read_image(get_filenames(params.projections)[0]).astype(np.float) last_index = params.start + params.number if params.number else -1 last = read_image(get_filenames(params.projections)[last_index]).astype(np.float) if params.darks and params.flats: dark = read_image(get_filenames(params.darks)[0]).astype(np.float) flat = read_image(get_filenames(params.flats)[0]) - dark first = flat_correct(flat, first - dark) last = flat_correct(flat, last - dark) height = params.height if params.height else -1 y_region = slice(params.y, min(params.y + height, first.shape[0]), params.y_step) first = first[y_region, :] last = last[y_region, :] return compute_rotation_axis(first, last) def compute_rotation_axis(first_projection, last_projection): """ Compute the tomographic rotation axis based on cross-correlation technique. *first_projection* is the projection at 0 deg, *last_projection* is the projection at 180 deg. """ from scipy.signal import fftconvolve width = first_projection.shape[1] first_projection = first_projection - first_projection.mean() last_projection = last_projection - last_projection.mean() # The rotation by 180 deg flips the image horizontally, in order # to do cross-correlation by convolution we must also flip it # vertically, so the image is transposed and we can apply convolution # which will act as cross-correlation convolved = fftconvolve(first_projection, last_projection[::-1, :], mode='same') center = np.unravel_index(convolved.argmax(), convolved.shape)[1] return (width / 2.0 + center) / 2 tofu-0.12.0/tofu/tasks.py000066400000000000000000000024431423713721100152270ustar00rootroot00000000000000import logging from gi.repository import Ufo LOG = logging.getLogger(__name__) PLUGIN_MANAGER = Ufo.PluginManager() def get_task(name, processing_node=None, **kwargs): task = PLUGIN_MANAGER.get_task(name) task.set_properties(**kwargs) if processing_node and task.uses_gpu(): LOG.debug("Assigning task '%s' to node %d", name, processing_node.get_index()) task.set_proc_node(processing_node) return task def get_writer(params): if 'dry_run' in params and params.dry_run: LOG.debug("Discarding data output") return get_task('null', download=True) outname = params.output LOG.debug("Writing output to {}".format(outname)) writer = get_task('write', filename=outname) writer.props.append = params.output_append if params.output_bitdepth != 32: writer.props.bits = params.output_bitdepth if params.output_minimum is not None and params.output_maximum is not None: writer.props.minimum = params.output_minimum writer.props.maximum = params.output_maximum if hasattr (writer.props, 'bytes_per_file'): writer.props.bytes_per_file = params.output_bytes_per_file if hasattr(writer.props, 'tiff_bigtiff'): writer.props.tiff_bigtiff = params.output_bytes_per_file > 2 ** 32 return writer tofu-0.12.0/tofu/tests/000077500000000000000000000000001423713721100146675ustar00rootroot00000000000000tofu-0.12.0/tofu/tests/__init__.py000066400000000000000000000000001423713721100167660ustar00rootroot00000000000000tofu-0.12.0/tofu/tests/composites/000077500000000000000000000000001423713721100170545ustar00rootroot00000000000000tofu-0.12.0/tofu/tests/composites/cmp.cm000066400000000000000000000056661423713721100201710ustar00rootroot00000000000000{ "name": "cmp", "caption": "cmp", "models": { "Null": { "model": { "caption": "Null", "properties": { "download": [ false, true ], "finish": [ false, true ], "durations": [ false, true ] } }, "visible": true, "position": { "x": 756.0, "y": 272.0 }, "name": "null" }, "Read": { "model": { "caption": "Read", "properties": { "path": [ ".", true ], "start": [ 0, false ], "number": [ 4294967295, true ], "step": [ 1, false ], "y": [ 0, false ], "height": [ 0, false ], "y-step": [ 1, false ], "convert": [ true, false ], "raw-width": [ 0, false ], "raw-height": [ 0, false ], "raw-bitdepth": [ 0, false ], "raw-pre-offset": [ 0, false ], "raw-post-offset": [ 0, false ], "type": [ "unspecified", false ], "retries": [ 0, false ], "retry-timeout": [ 1, false ] } }, "visible": true, "position": { "x": 266.0, "y": 242.0 }, "name": "read" } }, "connections": [ [ "Read", 0, "Null", 0 ] ], "links": [] }tofu-0.12.0/tofu/tests/composites/cmp_2.cm000066400000000000000000000034751423713721100204060ustar00rootroot00000000000000{ "name": "cmp_2", "caption": "cmp_2", "models": { "Dummy Data": { "model": { "caption": "Dummy Data", "properties": { "width": [ 1, true ], "height": [ 1, true ], "depth": [ 1, true ], "number": [ 1, true ], "init": [ 0.0, true ], "metadata": [ false, true ] } }, "visible": true, "position": { "x": 359.0, "y": 357.0 }, "name": "dummy_data" }, "Null": { "model": { "caption": "Null", "properties": { "download": [ false, true ], "finish": [ false, true ], "durations": [ false, true ] } }, "visible": true, "position": { "x": 859.0, "y": 459.0 }, "name": "null" } }, "connections": [ [ "Dummy Data", 0, "Null", 0 ] ], "links": [] }tofu-0.12.0/tofu/tests/conftest.py000066400000000000000000000030441423713721100170670ustar00rootroot00000000000000import pytest from PyQt5.QtWidgets import QInputDialog from tofu.flow.main import get_filled_registry from tofu.flow.scene import UfoScene from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel @pytest.fixture(scope='function') def nodes(monkeypatch): reg = get_filled_registry() scene = UfoScene(reg) nodes = {} # Composite node for name in ['read', 'pad']: model_cls = reg.create(name) node = scene.create_node(model_cls) node.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes['cpm'] = scene.create_composite() nodes['cpm'].graphics_object.setSelected(False) # Simple nodes for i in range(5): name = f'read_{i}' if i else 'read' model_cls = reg.create('read') nodes[name] = scene.create_node(model_cls) model_cls = reg.create('image_viewer') nodes['image_viewer'] = scene.create_node(model_cls) model_cls = reg.create('average') nodes['average'] = scene.create_node(model_cls) return nodes @pytest.fixture(scope='function') def scene(): reg = get_filled_registry() return UfoScene(reg) @pytest.fixture(scope='function') def scene_with_composite(nodes): return UfoScene(nodes['cpm'].model._registry) @pytest.fixture(scope='function') def node_model(): model = NodeTreeModel() model.setColumnCount(1) return model @pytest.fixture(scope='function') def link_model(node_model): model = PropertyLinksModel(node_model) return model tofu-0.12.0/tofu/tests/flow_util.py000066400000000000000000000016661423713721100172560ustar00rootroot00000000000000def populate_link_model(link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] records = [[read, read.model, 'number'], [read_2, read_2.model, 'height'], [composite, composite.model['Read'], 'y']] for (i, (node, model, prop)) in enumerate(records): link_model.add_item(node, model, prop, 0, i) return records def get_index_from_treemodel(node_model, row, prop_name): item = node_model.item(row, 0) i = 0 prop_item = item.child(i) while prop_item.text() != prop_name: i += 1 prop_item = item.child(i) return node_model.indexFromItem(prop_item) def add_nodes_to_scene(scene, model_names=None): if not model_names: model_names = ['read'] nodes = [] for name in model_names: model_cls = scene.registry.create(name) nodes.append(scene.create_node(model_cls)) return nodes tofu-0.12.0/tofu/tests/test_flow_execution.py000066400000000000000000000114371423713721100213400ustar00rootroot00000000000000import pytest from tofu.flow.execution import get_gpu_splitting_models, UfoExecutor from tofu.flow.main import get_filled_registry from tofu.flow.scene import UfoScene @pytest.fixture(scope='function') def scene(): reg = get_filled_registry() scene = UfoScene(reg) for name in ['dummy_data', 'pad', 'null']: # Set nodes as scene attributes for convenience setattr(scene, name, scene.create_node(reg.create(name))) scene.create_connection(scene.dummy_data['output'][0], scene.pad['input'][0]) scene.create_connection(scene.pad['output'][0], scene.null['input'][0]) return scene @pytest.fixture(scope='function') def executor(): return UfoExecutor() class TestUfoExecutor: def test_init(self, executor): ... def test_reset(self, executor): assert not executor._aborted assert executor._schedulers == [] assert executor.num_generated == 0 def test_abort(self, executor): self.called = False def slot(): self.called = True executor.execution_finished.connect(slot) executor.abort() assert self.called def test_on_processed(self, executor): self.num_generated = 0 def slot(): self.num_generated += 1 executor.processed_signal.connect(slot) executor.on_processed(None) executor.on_processed(None) assert self.num_generated == executor.num_generated == 2 def test_setup_ufo_graph(self, qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] gpus = executor._resources.get_gpu_nodes() assert gpus executor.setup_ufo_graph(graph, gpu=gpus[0], region=None, signalling_model=scene.dummy_data.model) def test_run_ufo_graph(self, qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] gpus = executor._resources.get_gpu_nodes() assert gpus ufo_graph = executor.setup_ufo_graph(graph, gpu=gpus[0], region=None, signalling_model=scene.dummy_data.model) # Run with default scheduler executor._run_ufo_graph(ufo_graph, False) # Run with fixed scheduler executor._run_ufo_graph(ufo_graph, True) # def test_check_graph(self, qtbot, scene, executor): # # TODO: implement this when memory-in is implemented and there is something to test def test_run(self, qtbot, scene, executor): def on_num_inputs_changed(number): self.num_inputs = number def on_processed(number): self.num_processed = number def on_execution_started(): self.started = True def on_execution_finished(): self.finished = True def on_exception_occured(): self.exception = True scene.dummy_data.model['number'] = 10 graph = scene.get_simple_node_graphs()[0] self.num_inputs = 0 self.num_processed = 0 self.started = False self.finished = False self.exception = None executor.number_of_inputs_changed.connect(on_num_inputs_changed) executor.processed_signal.connect(on_processed) executor.execution_started.connect(on_execution_started) executor.execution_finished.connect(on_execution_finished) executor.exception_occured.connect(on_exception_occured) with qtbot.waitSignal(signal=executor.execution_finished, timeout=100000): executor.run(graph) assert self.num_inputs == scene.dummy_data.model['number'] assert self.num_processed == scene.dummy_data.model['number'] assert self.started assert self.finished assert self.exception is None scene.remove_node(scene.dummy_data) # Create a reader and point it to a nonexistent path so that it raises an exception and # check that this exception has been processed byt the executor setattr(scene, 'read', scene.create_node(scene.registry.create('read'))) scene.create_connection(scene.read['output'][0], scene.pad['input'][0]) # Make sure the path is nonsense scene.read.model['path'] = '/dfasf/fsdafsdaf/asd/asf' scene.read.model['number'] = 10 graph = scene.get_simple_node_graphs()[0] executor.swallow_run_exceptions = True with qtbot.waitSignal(signal=executor.execution_finished): executor.run(graph) assert self.exception def test_get_gpu_splitting_models(qtbot, scene, executor): graph = scene.get_simple_node_graphs()[0] assert len(get_gpu_splitting_models(graph)) == 0 scene.clear_scene() scene.create_node(scene.registry.create('general_backproject')) graph = scene.get_simple_node_graphs()[0] assert len(get_gpu_splitting_models(graph)) == 1 tofu-0.12.0/tofu/tests/test_flow_main.py000066400000000000000000000461271423713721100202650ustar00rootroot00000000000000import glob import os import pathlib import pkg_resources import pytest import sys from PyQt5.QtWidgets import QFileDialog, QInputDialog, QMessageBox from xdg import xdg_data_home from tofu.flow.execution import UfoExecutor from tofu.flow.main import ApplicationWindow, get_filled_registry, GlobalExceptionHandler from tofu.flow.scene import UfoScene from tofu.flow.util import FlowError from tofu.tests.flow_util import add_nodes_to_scene @pytest.fixture(scope='function') def app_window(qtbot, scene): window = ApplicationWindow(scene) qtbot.addWidget(window) return window class TestApplicationWindow: def test_init(self, qtbot, app_window): assert app_window.ufo_scene assert app_window.executor def test_on_save(self, monkeypatch, app_window): def getSaveFileNameDefault(inst, header, path, fltr): return (os.path.join(path, 'flow.flow'), True) def getSaveFileName(inst, header, path, fltr): return (os.path.join('foo', 'bar', 'flow.flow'), True) # Don't actually write to disk monkeypatch.setattr(UfoScene, "save", lambda *args: None) # Default directory monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_save() directory = os.path.join(xdg_data_home(), 'tofu', 'flows') assert os.path.exists(directory) assert app_window.last_dirs['scene'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName) app_window.on_save() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') # And used the next time monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_save() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') def test_on_open(self, monkeypatch, app_window): def getOpenFileNameDefault(inst, header, path, fltr): return (os.path.join(path, 'flow.flow'), True) def getOpenFileName(inst, header, path, fltr): return (os.path.join('foo', 'bar', 'flow.flow'), True) # Don't actually read from disk monkeypatch.setattr(UfoScene, "load", lambda *args: None) # Default directory monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault) app_window.on_open() directory = os.path.join(xdg_data_home(), 'tofu', 'flows') if not os.path.exists(directory): directory = pathlib.Path.home() assert app_window.last_dirs['scene'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileName) app_window.on_open() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') # And used the next time monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault) app_window.on_open() assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar') def test_on_exception_occured(self, qtbot, monkeypatch, app_window): def exec_(inst): self.message_shown = True self.message_shown = False monkeypatch.setattr(QMessageBox, "exec_", exec_) app_window.on_exception_occured('foo') assert self.message_shown def test_on_number_of_inputs_changed(self, qtbot, app_window): app_window.on_number_of_inputs_changed(123) assert app_window.progress_bar.maximum() == 123 def test_on_processed(self, qtbot, app_window): app_window.on_number_of_inputs_changed(100) app_window.on_processed(10) assert app_window.progress_bar.value() == 11 def test_on_nodes_duplicated(self, qtbot, app_window): node = add_nodes_to_scene(app_window.ufo_scene)[0] node.graphics_object.setSelected(True) app_window.ufo_scene.copy_nodes() nodes = list(app_window.ufo_scene.nodes.values()) assert nodes[0].graphics_object.pos().y() != nodes[1].graphics_object.pos().y() def test_on_selection_menu_about_to_show(self, qtbot, monkeypatch, app_window): # Nothing selected app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert not app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() # Only non-composite nodes nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null']) app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert not app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() # One composite monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) for i in range(2): nodes[i].graphics_object.setSelected(True) app_window.ufo_scene.create_composite() app_window.on_selection_menu_about_to_show() assert app_window.edit_composite_action.isEnabled() assert app_window.expand_composite_action.isEnabled() assert app_window.export_composite_action.isEnabled() # More composites monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) app_window.ufo_scene.clearSelection() nodes[-1].graphics_object.setSelected(True) app_window.ufo_scene.create_composite() for node in app_window.ufo_scene.nodes.values(): node.graphics_object.setSelected(True) app_window.on_selection_menu_about_to_show() assert not app_window.edit_composite_action.isEnabled() assert app_window.expand_composite_action.isEnabled() assert not app_window.export_composite_action.isEnabled() def test_skip_action(self, qtbot, app_window): # No nodes selected, menu item must be disabled app_window.on_selection_menu_about_to_show() assert not app_window.skip_action.isEnabled() # Add some nodes, conect them and disable one nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null']) app_window.ufo_scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) app_window.ufo_scene.create_connection(nodes[1]['output'][0], nodes[2]['input'][0]) average = nodes[1] average.graphics_object.setSelected(True) app_window.on_selection_menu_about_to_show() # Nodes selected, menu item must be enabled assert app_window.skip_action.isEnabled() def test_on_edit_composite(self, qtbot, scene_with_composite, app_window): app_window.ufo_scene = scene_with_composite node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] node.graphics_object.setSelected(True) app_window.on_edit_composite() qtbot.addWidget(node.model._other_view) assert node.model.is_editing def test_on_create_composite(self, qtbot, monkeypatch, scene_with_composite, app_window): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad']) # Link a model to the slider model = nodes[0].model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'Read', model) # Create a composite for node in app_window.ufo_scene.nodes.values(): node.graphics_object.setSelected(True) app_window.on_create_composite() composite = list(app_window.ufo_scene.nodes.values())[0].model slider_model, prop_name = app_window.run_slider_key assert slider_model == composite.get_model_from_path(['Read']) assert prop_name == 'number' def test_on_item_focus_in(self, qtbot, app_window, scene_with_composite): read, pad = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad']) # Simple node model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == model assert prop_name == 'number' app_window.fix_run_slider.setChecked(False) model = pad.model view_item = model._view._properties['y'].view_item app_window.on_item_focus_in(view_item, 'y', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == model assert prop_name == 'y' # Focus gets another widget, but the run slider must be linked to the one focused before the # fix option is checked app_window.fix_run_slider.setChecked(True) model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) slider_model, prop_name = app_window.run_slider_key assert slider_model == pad.model assert prop_name == 'y' def test_on_node_deleted(self, qtbot, monkeypatch, app_window, scene_with_composite): app_window.ufo_scene = scene_with_composite cpm, cpm_2, read = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm', 'cpm', 'read']) # Simple node model = read.model view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', model.caption, model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(read) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None # Composite node model = cpm.model.get_model_from_path(['Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(cpm) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None # Nested composite node cpm_2.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('parent', True)) app_window.on_create_composite() node = app_window.ufo_scene.selected_nodes()[0] model = node.model.get_model_from_path(['cpm 2', 'Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'parent->cpm 2->Read', model) # remove in the scene doesn't seem to emit the signal, so use the window app_window.on_node_deleted(node) slider_model, prop_name = app_window.run_slider_key assert slider_model is None assert prop_name is None def test_on_expand_composite(self, qtbot, scene_with_composite, app_window): app_window.ufo_scene = scene_with_composite nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm', 'cpm']) for node in nodes: node.graphics_object.setSelected(True) app_window.on_expand_composite() captions = {node.model.caption for node in app_window.ufo_scene.nodes.values()} assert captions == {'Read 2', 'Pad 2', 'Read', 'Pad'} # Run slider # Create yet another composite and select a reader inside node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] model = node.model.get_model_from_path(['Read']) view_item = model._view._properties['number'].view_item app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', model) node.graphics_object.setSelected(True) app_window.on_expand_composite() # After expansion, the reader's index will be 3 slider_model, prop_name = app_window.run_slider_key assert slider_model.caption == 'Read 3' assert prop_name == 'number' def test_on_import_composites(self, qtbot, monkeypatch, app_window): tests_directory = pkg_resources.resource_filename(__name__, 'composites') def getOpenFileNamesDefault(inst, header, path, fltr): # Let's pretend there are files file_names = [os.path.join(path, 'foo.cm')] return (file_names, True) def getOpenFileNames(inst, header, path, fltr): file_names = sorted(glob.glob(os.path.join(tests_directory, '*.cm'))) return (file_names, True) def exec_(inst): self.message_shown = True monkeypatch.setattr(QMessageBox, "exec_", exec_) # Nothing opened, nothing happens monkeypatch.setattr(QFileDialog, "getOpenFileNames", lambda *args: ([], True)) app_window.on_import_composites() # Default directory monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNamesDefault) directory = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites') if not os.path.exists(directory): directory = pathlib.Path.home() try: app_window.on_import_composites() except FileNotFoundError: # We don't care if there are files, just the last_dirs setting is important pass assert app_window.last_dirs['composite'] == directory # It's possible to open more than one at a time monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNames) app_window.on_import_composites() assert 'cmp' in app_window.ufo_scene.registry.registered_model_creators() assert 'cmp_2' in app_window.ufo_scene.registry.registered_model_creators() # When user picks a different directory it must be remembered assert app_window.last_dirs['composite'] == tests_directory # And used the next time self.message_shown = False app_window.on_import_composites() assert app_window.last_dirs['composite'] == tests_directory # Message about overwriting models must be shown assert self.message_shown def test_on_export_composite(self, qtbot, monkeypatch, scene_with_composite, app_window): tests_directory = pkg_resources.resource_filename(__name__, 'composites') def getSaveFileNameDefault(inst, header, path, fltr): return (os.path.join(path, self.file_name), True) def getSaveFileName(inst, header, path, fltr): return (os.path.join(tests_directory, self.file_name), True) def export_composite(inst, node, file_name): self.final_file_name = file_name # Nothing selected, must silently pass app_window.on_export_composite() # Make a composite node app_window.ufo_scene = scene_with_composite node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0] node.graphics_object.setSelected(True) monkeypatch.setattr(ApplicationWindow, "export_composite", export_composite) # Default directory monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) self.file_name = 'composite' directory = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites') app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') assert os.path.exists(directory) assert app_window.last_dirs['composite'] == directory # When user picks a different directory it must be remembered monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName) app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') assert app_window.last_dirs['composite'] == tests_directory # And used the next time monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault) app_window.on_export_composite() assert app_window.last_dirs['composite'] == tests_directory # .cm must not be added if it's present in the file name self.file_name = 'composite.cm' app_window.on_export_composite() assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm') def test_on_reset_view(self, qtbot, app_window): app_window.flow_view.scale_up() app_window.on_reset_view() assert app_window.flow_view.transform().m11() == pytest.approx(1) assert app_window.flow_view.transform().m22() == pytest.approx(1) def test_on_property_links_action(self, qtbot, app_window): qtbot.addWidget(app_window.property_links_widget) app_window.property_links_widget.show() assert app_window.property_links_widget.isVisible() def test_on_run(self, qtbot, monkeypatch, app_window): def executor_run(inst, graph): self.ran = True monkeypatch.setattr(UfoExecutor, "run", executor_run) nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'read', 'flat_field_correct', 'null']) i_0, i_1, ffc, null = nodes # No connections -> many graphs with pytest.raises(FlowError): app_window.on_run() assert app_window.run_action.isEnabled() app_window.ufo_scene.create_connection(i_0['output'][0], ffc['input'][0]) app_window.ufo_scene.create_connection(i_1['output'][0], ffc['input'][1]) app_window.ufo_scene.create_connection(ffc['output'][0], null['input'][0]) # One ffc input is not connected with pytest.raises(FlowError): app_window.on_run() assert app_window.run_action.isEnabled() # All connections present -> must run i_2 = add_nodes_to_scene(app_window.ufo_scene, model_names=['read'])[0] app_window.ufo_scene.create_connection(i_2['output'][0], ffc['input'][2]) self.ran = False app_window.on_run() assert self.ran assert not app_window.run_action.isEnabled() def test_on_execution_finished(self, qtbot, app_window): app_window.run_action.setEnabled(False) app_window.progress_bar.setMaximum(100) app_window.progress_bar.setValue(50) app_window.on_execution_finished() assert app_window.progress_bar.value() == -1 assert app_window.run_action.isEnabled() def test_global_exception_handler(qtbot): handler = GlobalExceptionHandler() def slot(text): handler.called_signal = True handler.exception_occured.connect(slot) handler.called_signal = False try: raise FlowError('foo') except: # Call the hook explicitly, sys.excinfo = ... doesn't seem to have effect handler.excepthook(*sys.exc_info()) assert handler.called_signal def test_get_filled_registry(): registry = get_filled_registry() assert 'read' in registry.registered_model_creators() tofu-0.12.0/tofu/tests/test_flow_models.py000066400000000000000000001643251423713721100206250ustar00rootroot00000000000000import pytest import numpy as np from PyQt5.QtCore import Qt from PyQt5.QtGui import QValidator from PyQt5.QtWidgets import QFileDialog, QInputDialog, QLineEdit from tofu.flow.main import get_filled_registry from tofu.flow.models import (CheckBoxViewItem, ComboBoxViewItem, get_composite_model_class, get_composite_model_classes, get_composite_model_classes_from_json, get_ufo_model_class, get_ufo_model_classes, ImageViewerModel, IntQLineEditViewItem, MultiPropertyView, NumberQLineEditViewItem, PropertyModel, PropertyView, QLineEditViewItem, RangeQLineEditViewItem, UfoGeneralBackprojectModel, UfoIntValidator, UfoMemoryOutModel, UfoModelError, UfoRangeValidator, UfoReadModel, UfoRetrievePhaseModel, UfoModel, UfoTaskModel, UfoVaryingInputModel, UfoWriteModel, ViewItem) from tofu.flow.scene import UfoScene from tofu.flow.util import CompositeConnection, MODEL_ROLE, PROPERTY_ROLE from tofu.tests.flow_util import populate_link_model def check_property_changed_emit(qtbot, view_item, expected, gui_func, gui_args, gui_kwargs=None, show=False): if gui_kwargs is None: gui_kwargs = {} def on_changed(vit): vit.change_called = True view_item.change_called = False view_item.property_changed.connect(on_changed) qtbot.addWidget(view_item.widget) if show: # without show the mouse click for QCheckBox doesn't happen, bug in pytest-qt? view_item.widget.show() # Store old value for later check of programmatic change old_value = view_item.get() # Simulate user interaction gui_func(*gui_args, **gui_kwargs) # Value must have been set assert view_item.get() == expected # Signal must have been emitted assert view_item.change_called view_item.change_called = False view_item.set(old_value) # Signal must be emitted only on user interacion, not programmatic access assert not view_item.change_called def make_properties(): return { 'int': [IntQLineEditViewItem(0, 100, default_value=10), True], 'float': [NumberQLineEditViewItem(0, 100, default_value=0), True], 'string': [QLineEditViewItem(default_value='foo'), True], 'range': [RangeQLineEditViewItem(default_value=[1, 2, 3], num_items=3, is_float=True), True], 'choices': [ComboBoxViewItem(['a', 'b', 'c']), True], 'check': [CheckBoxViewItem(checked=True), True] } class DummyPropertyModel(PropertyModel): def make_properties(self): return make_properties() @pytest.fixture(scope='function') def property_view(): return PropertyView(properties=make_properties(), scrollable=False) @pytest.fixture(scope='function') def multi_property_view(nodes): groups = {nodes['cpm'].model: True, nodes['read'].model: False} return MultiPropertyView(groups=groups) def make_composite_model_class(nodes, name='foobar'): # We want to connect cpm:Pad to average, thus we need to get the outside port of the cpm # composite which corresponds to the pad model pad_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] connections = [CompositeConnection('cpm', pad_index, 'Average', 0)] state = [('cpm', nodes['cpm'].model.save(), True, None), ('average', nodes['average'].model.save(), True, None)] return get_composite_model_class(name, state, connections) def create_scene(qtbot, registry): scene = UfoScene(registry=registry) if scene.views(): for view in scene.views(): qtbot.addWidget(view) return scene def make_composite_node_in_scene(qtbot, nodes): model_cls = make_composite_model_class(nodes) registry = get_filled_registry() # Register both composites so that we can create them registry.register_model(nodes['cpm'].model.__class__, category='Composite', registry=registry) registry.register_model(model_cls, category='Composite', registry=registry) scene = create_scene(qtbot, registry) node = scene.create_node(model_cls) return (scene, node) @pytest.fixture(scope='function') def composite_model(nodes): # Make sure 'cpm', which is inside this composite model, has been registered registry = nodes['cpm'].model._registry model_cls = make_composite_model_class(nodes) registry.register_model(model_cls, category='Composite', registry=registry) return model_cls(registry=registry) @pytest.fixture(scope='function') def general_backproject(qtbot): model = UfoGeneralBackprojectModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def read_model(qtbot): model = UfoReadModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def write_model(qtbot): model = UfoWriteModel() qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def memory_out_model(qtbot): model = UfoMemoryOutModel() model['width'] = 100 model['height'] = 100 qtbot.addWidget(model.embedded_widget()) return model @pytest.fixture(scope='function') def image_viewer_model(qtbot): model = ImageViewerModel() qtbot.addWidget(model.embedded_widget()) return model def test_ufo_int_validator(): validator = UfoIntValidator(-10, 10) def check(input_str, expected): assert validator.validate(input_str, -1)[0] == expected check('0', QValidator.Acceptable) check('1', QValidator.Acceptable) check('-1', QValidator.Acceptable) check('101', QValidator.Intermediate) check('-101', QValidator.Intermediate) check('-', QValidator.Intermediate) check('1.', QValidator.Invalid) check('1.0', QValidator.Invalid) check('asdf', QValidator.Invalid) validator = UfoIntValidator(3, 10) check('1', QValidator.Intermediate) def test_ufo_range_validator(): def check(validator, input_str, expected): assert validator.validate(input_str, len(input_str))[0] == expected # Integer validator = UfoRangeValidator(num_items=3, is_float=False) check(validator, ',,', QValidator.Intermediate) check(validator, ' ,,', QValidator.Intermediate) check(validator, '1,1,', QValidator.Intermediate) check(validator, ',1,', QValidator.Intermediate) check(validator, '1,-2,3', QValidator.Acceptable) check(validator, '1,1.0,1', QValidator.Invalid) check(validator, '-1,s,-1', QValidator.Invalid) check(validator, '1,1,1,1', QValidator.Invalid) check(validator, '1,1,1,', QValidator.Invalid) # Float validator = UfoRangeValidator(num_items=3, is_float=True) check(validator, ',,', QValidator.Intermediate) check(validator, ' ,,', QValidator.Intermediate) check(validator, '.,,', QValidator.Intermediate) check(validator, '.e,,', QValidator.Intermediate) check(validator, '.e-,,', QValidator.Intermediate) check(validator, '.e+,,', QValidator.Intermediate) check(validator, '1.0e,,', QValidator.Intermediate) check(validator, '1.0e+,,', QValidator.Intermediate) check(validator, '1.0e-,,', QValidator.Intermediate) check(validator, '1e,,', QValidator.Intermediate) check(validator, '1e+,,', QValidator.Intermediate) check(validator, '1e-,,', QValidator.Intermediate) check(validator, '.1e,,', QValidator.Intermediate) check(validator, '.1e+,,', QValidator.Intermediate) check(validator, '.1e-,,', QValidator.Intermediate) check(validator, '1,1,1', QValidator.Acceptable) check(validator, '-1,1,1', QValidator.Acceptable) check(validator, '1.,1.,1', QValidator.Acceptable) check(validator, '-1.,1.,1', QValidator.Acceptable) check(validator, '1.0e1,1.0,1', QValidator.Acceptable) check(validator, '1.0e+1,1.0,1', QValidator.Acceptable) check(validator, '1.0e-1,1.0,1', QValidator.Acceptable) check(validator, '.1,1.0,1', QValidator.Acceptable) check(validator, '.1e-1,1.0,1', QValidator.Acceptable) check(validator, '.1e+1,1.0,1', QValidator.Acceptable) check(validator, '.1e1,1.0,1', QValidator.Acceptable) check(validator, 'e,,', QValidator.Invalid) check(validator, 'e.,,', QValidator.Invalid) check(validator, '+e,,', QValidator.Invalid) check(validator, '-e,,', QValidator.Invalid) check(validator, '+e.,,', QValidator.Invalid) check(validator, '-e.,,', QValidator.Invalid) check(validator, '1+,,', QValidator.Invalid) check(validator, '1-,,', QValidator.Invalid) check(validator, 'gfd,1,3', QValidator.Invalid) def test_view_item_init(qtbot): def get(inst): return inst.widget.text() def set(inst, value): inst.widget.setText(value) def on_changed(vit): vit.change_called = True ViewItem.get = get ViewItem.set = set edit = QLineEdit() qtbot.addWidget(edit) vit = ViewItem(edit, default_value='foo', tooltip='tooltip') edit.textEdited.connect(vit.on_changed) assert vit.widget.toolTip() == 'tooltip' assert vit.widget.text() == 'foo' check_property_changed_emit(qtbot, vit, 'fooa', qtbot.keyClick, (edit, 'a')) def test_check_box_view_item(qtbot): assert CheckBoxViewItem(checked=True).get() vit = CheckBoxViewItem(checked=False, tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert not vit.get() check_property_changed_emit(qtbot, vit, True, qtbot.mouseClick, (vit.widget, Qt.LeftButton), show=True) def test_combo_box_view_item(qtbot): items = ['a', 'b', 'c'] vit = ComboBoxViewItem(items, default_value='b', tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert vit.get() == 'b' check_property_changed_emit(qtbot, vit, 'c', qtbot.keyClick, (vit.widget, 'c')) def test_qline_edit_view_item(qtbot): vit = QLineEditViewItem(default_value='foo', tooltip='tooltip') assert vit.widget.toolTip() == 'tooltip' assert vit.get() == 'foo' check_property_changed_emit(qtbot, vit, 'fooc', qtbot.keyClick, (vit.widget, 'c')) def test_number_qline_edit_view_item(qtbot): with pytest.raises(ValueError): NumberQLineEditViewItem(-100, 100, default_value=1000) with pytest.raises(ValueError): NumberQLineEditViewItem(-100, 100, default_value=-1000) vit = NumberQLineEditViewItem(-100., 100., default_value=0., tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == 0 # is 0.0, after key click "1" will be 0.01 check_property_changed_emit(qtbot, vit, 0.01, qtbot.keyClick, (vit.widget, '1')) def test_int_qline_edit_view_item(qtbot): with pytest.raises(ValueError): IntQLineEditViewItem(-100, 100, default_value=1000) with pytest.raises(ValueError): IntQLineEditViewItem(-100, 100, default_value=-1000) vit = IntQLineEditViewItem(-100, 100, default_value=0, tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == 0 # is 0, after key click "1" will be 01, thus 1 check_property_changed_emit(qtbot, vit, 1, qtbot.keyClick, (vit.widget, '1')) def test_range_edit_view_item(qtbot): vit = RangeQLineEditViewItem(default_value=[1.0, 2.0, 3.0], tooltip='tooltip') assert vit.widget.toolTip().startswith('tooltip') assert vit.get() == [1.0, 2.0, 3.0] # Last is 3.0, after key click "1" will be 3.01 check_property_changed_emit(qtbot, vit, [1.0, 2.0, 3.01], qtbot.keyClick, (vit.widget, '1')) class TestPropertyView: def test_init(self, qtbot, property_view): assert len(property_view.property_names) > 0 # Defaults must pass PropertyView() def test_get_property(self, qtbot, property_view): assert property_view.get_property('int') == 10 def test_set_property(self, qtbot, property_view): property_view.set_property('int', 50) assert property_view.get_property('int') == 50 def test_on_property_changed(self, qtbot, property_view): widget = property_view._properties['int'].view_item.widget qtbot.addWidget(widget) qtbot.keyClick(widget, '0') assert property_view.get_property('int') == 100 def test_is_property_visible(self, qtbot, property_view): assert property_view.is_property_visible('int') def test_set_property_visible(self, qtbot, property_view): visible = not property_view.is_property_visible('int') property_view.set_property_visible('int', visible) assert property_view.is_property_visible('int') == visible def test_restore_properties(self, qtbot, property_view): props = property_view.export_properties() property_view.set_property('int', props['int'][0] + 1) property_view.restore_properties(props) assert property_view.get_property('int') == props['int'][0] def test_export_properties(self, qtbot, property_view): props = property_view.export_properties() assert 'int' in props assert props['int'][0] == property_view.get_property('int') assert props['int'][1] == property_view.is_property_visible('int') class TestMultiPropertyView: def test_init(self, qtbot, multi_property_view): assert len(list(iter(multi_property_view))) == 2 def test_getitem(self, qtbot, multi_property_view, nodes): assert multi_property_view['cpm'] == nodes['cpm'].model def test_contains(self, qtbot, multi_property_view): assert 'cpm' in multi_property_view assert 'foo' not in multi_property_view def test_iter(self, qtbot, multi_property_view): assert set(list(iter(multi_property_view))) == set(['cpm', 'Read']) def test_export_groups(self, qtbot, multi_property_view): state = multi_property_view.export_groups() multi_property_view.set_group_visible('Read', False) assert state['Read']['model']['caption'] == 'Read' assert not state['Read']['visible'] def test_restore_groups(self, qtbot, multi_property_view, nodes): multi_property_view['Read']['number'] = 100 state = multi_property_view.export_groups() multi_property_view['Read']['number'] = 1000 multi_property_view.restore_groups(state) assert multi_property_view['Read']['number'] def test_set_group_visible(self, qtbot, multi_property_view): visible = not multi_property_view.is_group_visible('cpm') multi_property_view.set_group_visible('cpm', visible) assert multi_property_view.is_group_visible('cpm') == visible def test_is_group_visible(self, qtbot, multi_property_view): assert multi_property_view.is_group_visible('cpm') assert not multi_property_view.is_group_visible('Read') class TestUfoModel: def test_init(self): model = UfoModel() assert model.caption == model.base_caption def test_restore(self): model = UfoModel() state = {'caption': 'foo'} old_caption = model.caption model.restore(state, restore_caption=False) assert model.caption == old_caption model.restore(state, restore_caption=True) assert model.caption == 'foo' # 'caption' not in state, the old one must be preserved model = UfoModel() old_caption = model.caption model.restore({}, restore_caption=True) assert model.caption == old_caption def save(self): model = UfoModel() model.caption = 'foo' assert model.save()['caption'] == 'foo' class TestPropertyModel: def test_init(self, qtbot): PropertyModel() model = DummyPropertyModel() # make_properties must be called assert set(model.properties) == set(make_properties().keys()) def test_getitem(self, qtbot): model = DummyPropertyModel() model['int'] with pytest.raises(KeyError): model['foo'] def test_setitem(self, qtbot): model = DummyPropertyModel() model['int'] = 132 assert model['int'] == 132 def test_contains(self, qtbot): model = DummyPropertyModel() assert 'int' in model assert 'foo' not in model def test_iter(self, qtbot): model = DummyPropertyModel() assert set(iter(model)) == set(make_properties().keys()) def test_on_property_changed(self, qtbot): def callback(model, name, value): self.called_name = name self.called_value = value model = DummyPropertyModel() model.property_changed.connect(callback) widget = model._view._properties['int'].view_item.widget qtbot.addWidget(widget) qtbot.keyClick(widget, '0') assert self.called_value == model['int'] assert self.called_name == 'int' def test_make_properties(self, qtbot): props = DummyPropertyModel().make_properties() assert props.keys() == make_properties().keys() assert PropertyModel().make_properties() == {} def test_copy_properties(self, qtbot): model = DummyPropertyModel() model['int'] = 123 visible = not model._view.is_property_visible('int') model._view.set_property_visible('int', visible) properties = model.copy_properties() # It has to be a deep copy, so changing the model properties cannot affect the copy model['int'] = 12 model._view.set_property_visible('int', not visible) assert properties['int'][0].get() == 123 assert properties['int'][1] == visible def test_embedded_widget(self, qtbot): assert PropertyModel().embedded_widget() is None assert isinstance(DummyPropertyModel().embedded_widget(), PropertyView) def test_restore(self, qtbot): model = DummyPropertyModel() state = model.save() old_value = model['int'] old_caption = model.caption visible = not model._view.is_property_visible('int') model['int'] = old_value + 1 model._view.set_property_visible('int', visible) model.caption = 'Foo' model.restore(state, restore_caption=False) assert model['int'] == old_value assert model._view.is_property_visible('int') == (not visible) assert model.caption == 'Foo' model.restore(state, restore_caption=True) assert model.caption == old_caption def test_save(self, qtbot): model = DummyPropertyModel() old_value = model['int'] visible = not model._view.is_property_visible('int') model['int'] = old_value + 1 model._view.set_property_visible('int', visible) model.caption = 'Foo' state = model.save() assert state['properties']['int'][0] == old_value + 1 assert state['properties']['int'][1] == visible assert state['caption'] == 'Foo' class TestUfoTaskModel: def test_init(self, qtbot): model = UfoTaskModel('flat-field-correct') assert model.properties # A task doesn't need any special treatment by default assert not model.expects_multiple_inputs assert not model.can_split_gpu_work assert not model.needs_fixed_scheduler def test_make_properties(self, qtbot): model = UfoTaskModel('flat-field-correct') # Config takes effect assert not model._view.is_property_visible('dark-scale') def test_create_ufo_task(self, qtbot): model = UfoTaskModel('flat-field-correct') model['dark-scale'] = 12.3 task = model.create_ufo_task() assert task.props.dark_scale == pytest.approx(12.3) def test_uses_gpu(self, qtbot): model = UfoTaskModel('flat-field-correct') assert model.uses_gpu model = UfoTaskModel('read') assert not model.uses_gpu def test_get_ufo_model_class(qtbot): # flat correction is a fairly complicated task to test task_name = 'flat-field-correct' model_cls = get_ufo_model_class(task_name) # Model class attributes assert model_cls.name == 'flat_field_correct' model = model_cls() # Model instance attributes assert model.num_ports['input'] == 3 assert model.num_ports['output'] == 1 assert model.port_caption['input'][0] == 'radios' assert model.port_caption['input'][1] == 'darks' assert model.port_caption['input'][2] == 'flats' assert model.port_caption['output'][0] == '' class TestBaseCompositeModel: def test_init(self, qtbot, monkeypatch, composite_model, scene): # cpm has 1 input and 2 outputs (read and pad are not connected) and average has 1 input and # 1 output, but cpm is connected with average, which reduces both port types by 1 assert composite_model.num_ports['input'] == 1 assert composite_model.num_ports['output'] == 2 for port_type in ['input', 'output']: for i in range(composite_model.num_ports[port_type]): submodel, j = composite_model.get_model_and_port_index(port_type, i) subcaption = submodel.port_caption[port_type][j] if subcaption: subcaption = ':' + subcaption assert (composite_model.port_caption[port_type][i] == submodel.caption + subcaption) assert composite_model._view # num-inputs must take effect monkeypatch.setattr(QInputDialog, "getInt", lambda *args, **kwargs: (2, True)) monkeypatch.setattr(QInputDialog, "getText", lambda *args, **kwargs: ('with-pr', True)) node = scene.create_node(scene.registry.create('retrieve_phase')) node.graphics_object.setSelected(True) node = scene.create_composite() assert node.model.get_model_from_path(['Retrieve Phase']).num_ports['input'] == 2 # and it must not affect default registry creators kwargs = scene.registry.registered_model_creators()['retrieve_phase'][1] assert 'num_inputs' not in kwargs def test_getitem(self, qtbot, composite_model, nodes): assert composite_model['cpm'] assert composite_model['Average'] with pytest.raises(KeyError): composite_model['foo'] def test_contains(self, qtbot, composite_model, nodes): assert 'cpm' in composite_model assert 'foo' not in composite_model def test_iter(self, qtbot, composite_model): assert set(list(iter(composite_model))) == set(['cpm', 'Average']) def test_get_descendant_graph(self, qtbot, monkeypatch, composite_model, nodes): graph = composite_model.get_descendant_graph() cpm = composite_model['cpm'] assert (composite_model, cpm) in graph.edges assert (composite_model, composite_model['Average']) in graph.edges assert (cpm, cpm['Read']) in graph.edges assert (cpm, cpm['Pad']) in graph.edges with pytest.raises(ValueError): composite_model.get_descendant_graph(in_subwindow=True) # Subwindow editing composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) graph = composite_model.get_descendant_graph(in_subwindow=True) cpm = composite_model._window_nodes['cpm'].model average = composite_model._window_nodes['Average'].model assert (composite_model, cpm) in graph.edges assert (composite_model, average) in graph.edges assert (cpm, cpm['Read']) in graph.edges assert (cpm, cpm['Pad']) in graph.edges composite_model._other_view.close() # Create outer composite with foobar inside, get_descendant_graph with in_subwindow=True # when outer is being edited and foobar not must return outer subwindow models and foobar's # internal models scene = create_scene(qtbot, composite_model._registry) inner = scene.create_node(composite_model.__class__) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True)) inner.graphics_object.setSelected(True) outer = scene.create_composite().model outer.edit_in_window() qtbot.addWidget(outer._other_view) graph = outer.get_descendant_graph(in_subwindow=True) inner = outer._window_nodes['foobar'].model assert (outer, inner) in graph.edges assert (inner, inner['Average']) in graph.edges assert (inner['cpm'], inner['cpm']['Read']) in graph.edges outer._other_view.close() def test_contains_path(self, qtbot, composite_model, nodes): assert composite_model.contains_path(['Average']) assert composite_model.contains_path(['cpm']) assert composite_model.contains_path(['cpm', 'Read']) assert not composite_model.contains_path(['cpm', 'Read 2']) assert not composite_model.contains_path(['foo']) def test_get_model_from_path(self, qtbot, composite_model, nodes): assert composite_model.get_model_from_path(['cpm', 'Read']) with pytest.raises(KeyError): composite_model.get_model_from_path(['foo']) def test_is_model_inside(self, qtbot, composite_model, nodes): model = composite_model.get_model_from_path(['cpm']) assert composite_model.is_model_inside(model) model = composite_model.get_model_from_path(['cpm', 'Read']) assert composite_model.is_model_inside(model) assert not composite_model.is_model_inside(nodes['read_2'].model) def test_get_path_from_model(self, qtbot, composite_model, nodes): cpm = composite_model['cpm'] path = composite_model.get_path_from_model(cpm) assert path == [composite_model, cpm] path = composite_model.get_path_from_model(cpm['Read']) assert path == [composite_model, cpm, cpm['Read']] model = composite_model['cpm']['Read'] path = composite_model.get_path_from_model(model) assert path == [composite_model, cpm, model] with pytest.raises(KeyError): composite_model.get_path_from_model(nodes['read_2'].model) def test_leaf_paths(self, qtbot, composite_model, nodes): leaves = composite_model.get_leaf_paths(in_subwindow=False) cpm = composite_model['cpm'] assert len(leaves) == 3 assert [composite_model, cpm, cpm['Read']] in leaves assert [composite_model, cpm, cpm['Pad']] in leaves assert [composite_model, composite_model['Average']] in leaves def test_set_property_links_model(self, qtbot, link_model, composite_model): composite_model.property_links_model = link_model assert composite_model.property_links_model == link_model # The property links model must be set also for children assert composite_model['cpm'].property_links_model == link_model def test_get_outside_port(self, qtbot, composite_model): # There is one input corresponding to cpm's pad model cpm = composite_model.get_model_from_path(['cpm']) pad_index = cpm.get_outside_port('Pad', 'input', 0)[1] composite_model.get_outside_port('cpm', 'input', pad_index) # and two outputs: cpm's read and average read_index = cpm.get_outside_port('Read', 'output', 0)[1] composite_model.get_outside_port('cpm', 'output', read_index) composite_model.get_outside_port('Average', 'output', 0) def test_get_model_and_port_index(self, qtbot, composite_model): model, index = composite_model.get_model_and_port_index('input', 0) cpm = composite_model.get_model_from_path(['cpm']) # There is only one input: Pad. Get it's internal cpm's index and compare with what the # outer composite object gives. pad_index = cpm.get_outside_port('Pad', 'input', 0)[1] assert model == cpm assert index == pad_index # There are two output ports, one from cpm's read model and one from average average = composite_model.get_model_from_path(['Average']) # Get read index from the cpm inside the composite_model and not from the cpm in the 'nodes' # fixsture because those are not the same instance and the read output index might be # different in those two instances because the ports are dictionaries read_index = cpm.get_outside_port('Read', 'output', 0)[1] outputs = [composite_model.get_model_and_port_index('output', 0)] outputs.append(composite_model.get_model_and_port_index('output', 1)) assert (cpm, read_index) in outputs assert (average, 0) in outputs def test_embedded_widget(self, qtbot, composite_model): assert isinstance(composite_model.embedded_widget(), MultiPropertyView) def test_restore(self, qtbot, composite_model): state = composite_model.save() old_value = composite_model['cpm']['Pad']['width'] old_caption = composite_model.caption visible = not composite_model._view.is_group_visible('cpm') composite_model['cpm']['Pad']['width'] = old_value + 1 composite_model._view.set_group_visible('cpm', visible) composite_model.caption = 'Foo' composite_model.restore(state, restore_caption=False) assert composite_model['cpm']['Pad']['width'] == old_value assert composite_model._view.is_group_visible('cpm') == (not visible) assert composite_model.caption == 'Foo' composite_model.restore(state, restore_caption=True) assert composite_model.caption == old_caption conn = composite_model._connections[0] assert [[conn.from_unique_name, conn.from_port_index, conn.to_unique_name, conn.to_port_index]] == state['connections'] def test_restore_links(self, qtbot, nodes): def check_links(node, link_model): assert link_model.rowCount() == 1 assert link_model.columnCount() == 3 assert link_model.find_items((node.model['cpm']['Read'], 'number'), (MODEL_ROLE, PROPERTY_ROLE)) assert link_model.find_items((node.model['cpm']['Pad'], 'height'), (MODEL_ROLE, PROPERTY_ROLE)) assert link_model.find_items((node.model['Average'], 'number'), (MODEL_ROLE, PROPERTY_ROLE)) assert not link_model.find_items((node.model['cpm']['Read'], 'height'), (MODEL_ROLE, PROPERTY_ROLE)) scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) # Set links to the newly created links node.model._links = node.model.save()['links'] # Link model has to have the exact same entries as before link_model.clear() node.model.restore_links(node) check_links(node, link_model) # Second time doesn't add the same links twice node.model.restore_links(node) check_links(node, link_model) def test_save(self, qtbot, nodes): scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model cpm = node.model['cpm'] link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) old_value = node.model['cpm']['Pad']['width'] visible = not node.model._view.is_group_visible('cpm') node.model['cpm']['Pad']['width'] = old_value + 1 node.model._view.set_group_visible('cpm', visible) node.model.caption = 'Foo' state = node.model.save() cpm_models_state = state['models']['cpm']['model']['models'] assert state['models']['cpm']['visible'] == visible assert cpm_models_state['Pad']['model']['properties']['width'][0] == old_value + 1 assert state['caption'] == 'Foo' cpm = node.model.get_model_from_path(['cpm']) pad_index = cpm.get_outside_port('Pad', 'output', 0)[1] assert state['connections'] == [['cpm', pad_index, 'Average', 0]] # Property links links = link_model.get_model_links([path[-1] for path in node.model.get_leaf_paths()]) links = [[str_path[1:] for str_path in row] for row in links.values()] saved = node.model.save()['links'] # One row assert len(saved) == len(links) == 1 # All linked paths must be saved for str_path in saved[0]: assert str_path in links[0] def test_on_connection_created(self, qtbot, composite_model): composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'cpm': read_index = node.model.get_outside_port('Read', 'output', 0)[1] pad_index = node.model.get_outside_port('Pad', 'input', 0)[1] output_port = node['output'][read_index] input_port = node['input'][pad_index] num_connections = len(composite_model._other_scene.connections) composite_model._other_scene.create_connection(output_port, input_port) # No new connections allowed assert len(composite_model._other_scene.connections) == num_connections composite_model._other_view.close() def test_on_connection_deleted(self, qtbot, composite_model): composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) num_connections = len(composite_model._other_scene.connections) composite_model._other_scene.delete_connection(composite_model._other_scene.connections[0]) # No connection deletions assert len(composite_model._other_scene.connections) == num_connections composite_model._other_view.close() def test_double_clicked(self, qtbot, composite_model): composite_model.double_clicked(None) qtbot.addWidget(composite_model._other_view) assert composite_model.is_editing and composite_model._other_view is not None def test_on_other_scene_double_clicked(self, qtbot, composite_model): composite_model.double_clicked(None) qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'cpm': node.model.double_clicked(composite_model._other_view) qtbot.addWidget(node.model._other_view) assert composite_model.is_editing and composite_model._other_view is not None break def test_expand_into_graph(self, qtbot, composite_model): import networkx as nx graph = nx.MultiDiGraph() composite_model.expand_into_graph(graph) src, dst, ports = list(graph.edges.data())[0] conn = composite_model._connections[0] gt = [conn.from_unique_name, conn.from_port_index, conn.to_unique_name, conn.to_port_index] conn_graph = [src.caption, ports['output'], dst.caption, ports['input']] assert conn_graph == gt def test_add_slave_links(self, qtbot, monkeypatch, nodes): def crosscheck(model, root_model, property_name, link_model): key = (model, property_name) root_key = (root_model, property_name) assert link_model._silent[key] == root_key assert key in link_model._slaves[root_key] scene, node = make_composite_node_in_scene(qtbot, nodes) link_model = scene.property_links_model link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0) link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1) link_model.add_item(node, node.model['Average'], 'number', 0, 2) # Not being edited, nothing registered node.model.add_slave_links() assert link_model._silent == {} node.model.edit_in_window() qtbot.addWidget(node.model._other_view) # Standard editing setup assert hasattr(node.model, '_other_scene') assert hasattr(node.model, '_other_view') assert not node.model._other_scene.allow_node_creation assert not node.model._other_scene.allow_node_deletion # Test foobar's subwindow, registering model is cpm and its internal models must be linked crosscheck(node.model._window_nodes['cpm'].model['Read'], node.model['cpm']['Read'], 'number', link_model) crosscheck(node.model._window_nodes['cpm'].model['Pad'], node.model['cpm']['Pad'], 'height', link_model) crosscheck(node.model._window_nodes['Average'].model, node.model['Average'], 'number', link_model) # Test foobar's subwindow and cpm's subwindow, cpm and also its models in the subwindow must # be linked cpm = node.model._window_nodes['cpm'].model cpm.edit_in_window() qtbot.addWidget(cpm._other_view) assert cpm.window_parent == node.model crosscheck(cpm._window_nodes['Read'].model, node.model['cpm']['Read'], 'number', link_model) crosscheck(cpm._window_nodes['Pad'].model, node.model['cpm']['Pad'], 'height', link_model) # Add one more composite layer, outer->foobar->cpm->Model, both registering model and its # window_parent must be jinked node.model._other_view.close() assert cpm._other_view is None monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outermost', True)) node.graphics_object.setSelected(True) outer = scene.create_composite() outer.model.edit_in_window() qtbot.addWidget(outer.model._other_view) node_sub = outer.model._window_nodes['foobar'] node_sub.model.edit_in_window() qtbot.addWidget(node_sub.model._other_view) cpm = node_sub.model._window_nodes['cpm'].model cpm.edit_in_window() qtbot.addWidget(cpm._other_view) crosscheck(node_sub.model._window_nodes['cpm'].model['Read'], outer.model['foobar']['cpm']['Read'], 'number', link_model) crosscheck(node_sub.model._window_nodes['cpm'].model['Pad'], outer.model['foobar']['cpm']['Pad'], 'height', link_model) crosscheck(node_sub.model._window_nodes['Average'].model, outer.model['foobar']['Average'], 'number', link_model) crosscheck(cpm._window_nodes['Read'].model, outer.model['foobar']['cpm']['Read'], 'number', link_model) crosscheck(cpm._window_nodes['Pad'].model, outer.model['foobar']['cpm']['Pad'], 'height', link_model) cpm._other_view.close() node_sub.model._other_view.close() outer.model._other_view.close() def test_edit_in_window(self, qtbot, nodes): composite_model = nodes['cpm'].model link_model = composite_model.property_links_model populate_link_model(link_model, nodes) composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) assert hasattr(composite_model, '_other_scene') assert hasattr(composite_model, '_other_view') assert not composite_model._other_scene.allow_node_creation assert not composite_model._other_scene.allow_node_deletion # Silent must have been added with root cpm's read model assert (list(link_model._slaves.keys())[0] == (composite_model.get_model_from_path(['Read']), 'y')) # Subcomposites must link to their parent models scene, node = make_composite_node_in_scene(qtbot, nodes) node.model.edit_in_window() qtbot.addWidget(node.model._other_view) assert node.model._window_nodes['cpm'].model.window_parent == node.model node.model._other_view.close() composite_model._other_view.close() def test_view_close_event(self, qtbot, nodes): composite_model = nodes['cpm'].model link_model = composite_model.property_links_model populate_link_model(link_model, nodes) composite_model.edit_in_window() qtbot.addWidget(composite_model._other_view) for node in composite_model._other_scene.nodes.values(): if node.model.caption == 'Read': widget = node.model.embedded_widget()._properties['y'].view_item.widget qtbot.addWidget(node.model.embedded_widget()) qtbot.keyClicks(widget, '11') else: # Pad node.model['width'] += 10 # Linked models must be updated immediately assert composite_model['Read']['y'] == 11 assert nodes['read'].model['number'] == 11 assert nodes['read_2'].model['height'] == 11 composite_model._other_view.close() # Original models in the composite must be updated after close assert composite_model['Pad']['width'] == 10 # Silent model must be removed (it was the only one, so test for {} is sufficient) assert link_model._slaves == {} assert link_model._silent == {} def test_expand_into_scene(self, qtbot, monkeypatch): def get_int(*args, **kwargs): return self.get_int_return monkeypatch.setattr(QInputDialog, "getInt", get_int) nodes = {} registry = get_filled_registry() scene = create_scene(qtbot, registry) # Composite node for name in ['read', 'pad']: model_cls = registry.create(name) node = scene.create_node(model_cls) node.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes['cpm'] = scene.create_composite() nodes['cpm'].graphics_object.setSelected(False) model_cls = registry.create('average') nodes['average'] = scene.create_node(model_cls) self.get_int_return = (2, True) model_cls = registry.create('retrieve_phase') nodes['retrieve_phase'] = scene.create_node(model_cls) # Add null node to create an outside connection null_cls = registry.create('null') null_node = scene.create_node(null_cls) # Make a property link scene.property_links_model.add_item(nodes['cpm'], nodes['cpm'].model['Read'], 'number', 0, 0) scene.property_links_model.add_item(nodes['cpm'], nodes['cpm'].model['Pad'], 'width', 0, 1) # Export composite and reload it so that it remembers the links (important for testing of # adding property link duplicates) cpm_cls_with_links = get_composite_model_classes_from_json(nodes['cpm'].model.save())[0] registry.register_model(cpm_cls_with_links, category='Composites', registry=registry) scene.remove_node(nodes['cpm']) nodes['cpm'] = scene.create_node(registry.create('cpm')) # Outer composite node has inside: read, pad, average; pad and average are connected # read and pad are encapsulated in an internal composite cpm pad_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] scene.create_connection(nodes['cpm']['output'][pad_index], nodes['average']['input'][0]) nodes['cpm'].graphics_object.setSelected(True) nodes['average'].graphics_object.setSelected(True) nodes['retrieve_phase'].graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('foobar', True)) scene.create_composite() composite_node = scene.selected_nodes()[0] composite_model = composite_node.model # Create outside connection from outer composite's average to null port_null = null_node['input'][0] # average_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1] # Get the average index dynamically because it might be mapped to a different output port # every time (reader in cpm makes another output) average_index = composite_model.get_outside_port('Average', 'output', 0)[1] port_composite = composite_node['output'][average_index] scene.create_connection(port_composite, port_null) # Change some property to see if it persists after expansion composite_model['cpm']['Read']['number'] = 123 # Make sure the nested num-inputs takes effect, i.e. QInputDialog.getInt invocation must # fail the test self.get_int_return = (None, False) composite_model.expand_into_scene(scene, composite_node) # Nodes must be there assert (set([node.model.caption for node in scene.nodes.values()]) == set(['Null', 'Average', 'Retrieve Phase', 'cpm'])) # num-inputs took effect for node in scene.nodes.values(): if node.model.caption == 'Retrieve Phase': assert node.model.num_ports['input'] == 2 break # Changed properties must be there for node in scene.nodes.values(): if node.model.caption == 'cpm': assert node.model['Read']['number'] == 123 break # Connections must be preserved for connection in scene.connections: if connection.get_node('output').model.caption == 'cpm': # Internal composite connection Pad -> Average must be there assert connection.get_node('input').model.caption == 'Average' cpm_index = connection.get_port_index('output') cpm_model = connection.get_node('output').model assert cpm_model.get_model_and_port_index('output', cpm_index)[0].caption == 'Pad' else: # Outside connection Average -> Null must be there assert connection.get_node('input').model.caption == 'Null' # Property links must be there assert scene.property_links_model.rowCount() == 1 # Original composite node must be gone assert composite_node not in scene.nodes.values() def test_get_composite_model_class(qtbot, nodes): model_cls = make_composite_model_class(nodes) with pytest.raises(AttributeError): # Registry must be provided model_cls() # Name must be provided with pytest.raises(UfoModelError): make_composite_model_class(nodes, name='') with pytest.raises(UfoModelError): make_composite_model_class(nodes, name=None) class TestUfoGeneralBackprojectModel: def test_init(self, general_backproject): assert general_backproject.num_ports['input'] == 1 assert general_backproject.num_ports['output'] == 1 assert general_backproject.needs_fixed_scheduler is True assert general_backproject.can_split_gpu_work is True def test_make_properties(self, general_backproject): props = general_backproject.make_properties() assert 'slice-memory-coeff' in props def test_split_gpu_work(self, general_backproject): from gi.repository import Ufo resources = Ufo.Resources() gpus = resources.get_gpu_nodes() general_backproject['x-region'] = [-100., 100., 1.] general_backproject['y-region'] = [-100., 100., 1.] general_backproject['region'] = [-100., 100., 1.] if gpus: # Normal operation assert general_backproject.split_gpu_work(gpus) # Wrong input general_backproject['x-region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['x-region'] = [-100., 100., 1.] general_backproject['y-region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['y-region'] = [-100., 100., 1.] general_backproject['region'] = [-100., -200., 1.] with pytest.raises(UfoModelError): general_backproject.split_gpu_work(gpus) general_backproject['region'] = [-100., 100., 1.] def test_create_ufo_task(self, general_backproject): general_backproject['region'] = [-100., 100., 1.] ufo_task = general_backproject.create_ufo_task(region=None) assert ufo_task.props.region == pytest.approx(general_backproject['region']) ufo_task = general_backproject.create_ufo_task(region=[-10., 10., 1.]) assert ufo_task.props.region == pytest.approx([-10., 10., 1.]) class TestUfoReadModel: def test_init(self, read_model): assert read_model.num_ports['input'] == 0 assert read_model.num_ports['output'] == 1 def test_double_clicked(self, qtbot, monkeypatch, read_model): from tofu.flow.filedirdialog import FileDirDialog monkeypatch.setattr(FileDirDialog, "exec_", lambda *args: 1) monkeypatch.setattr(FileDirDialog, "selectedFiles", lambda *args: ['foobarbaz']) read_model.double_clicked(None) assert read_model['path'] == 'foobarbaz' class TestUfoVaryingInputModel: def test_init(self, qtbot, monkeypatch): def get_int(*args, **kwargs): self.called = True return (1, True) # No number of inputs specified, dialog needs to pop up self.called = False monkeypatch.setattr(QInputDialog, 'getInt', get_int) model = UfoVaryingInputModel('opencl', num_inputs=None) qtbot.addWidget(model.embedded_widget()) assert self.called assert model.num_ports['input'] == 1 # e.g. opencl task can have multiple inputs model = UfoVaryingInputModel('opencl', num_inputs=4) qtbot.addWidget(model.embedded_widget()) assert model.num_ports['input'] == 4 assert len(model.data_type['input']) == 4 assert len(model.port_caption['input']) == 4 assert len(model.port_caption_visible['input']) == 4 def test_save(self, qtbot): model = UfoVaryingInputModel('opencl', num_inputs=4) qtbot.addWidget(model.embedded_widget()) assert model.save()['num-inputs'] == 4 class TestUfoRetrievePhaseModel: def test_distance_input(self, qtbot): model = UfoRetrievePhaseModel(num_inputs=4) qtbot.addWidget(model.embedded_widget()) validator = model._view._properties['distance'].view_item.widget.validator() # Validator accepts only 4 values assert validator.validate('1,2,3,4', 0)[0] == QValidator.Acceptable assert validator.validate('1,2,3', 0)[0] == QValidator.Intermediate assert validator.validate('1,2,3,4,5', 0)[0] == QValidator.Invalid def test_multidistance_fixed_method(self, qtbot): def check(num_inputs): model = UfoRetrievePhaseModel(num_inputs=num_inputs) qtbot.addWidget(model.embedded_widget()) enabled = num_inputs == 1 assert model._view._properties['method'].view_item.widget.isEnabled() == enabled if not enabled: assert model['method'] == 'ctf_multidistance' assert model._view._properties['distance-x'].view_item.widget.isEnabled() == enabled assert model._view._properties['distance-y'].view_item.widget.isEnabled() == enabled check(1) check(2) class TestUfoWriteModel: def test_init(self, write_model): assert write_model.num_ports['input'] == 1 assert write_model.num_ports['output'] == 0 def test_double_clicked(self, monkeypatch, write_model): monkeypatch.setattr(QFileDialog, "getSaveFileName", lambda *args: ('foobarbaz', None)) write_model.double_clicked(None) assert write_model['filename'] == 'foobarbaz' def test_expects_multiple_inputs(self, write_model): write_model['filename'] = 'foo{region}bar' assert write_model.expects_multiple_inputs write_model['filename'] = 'foobar' assert not write_model.expects_multiple_inputs def test_setup_ufo_task(self, write_model): write_model['filename'] = '{region}' # Must pass ufo_task = write_model.create_ufo_task(region=[0, 1, 1]) # Must fail with pytest.raises(UfoModelError): write_model.create_ufo_task(region=None) assert ufo_task.props.filename == '0' write_model['filename'] = 'foo.tif' # Must pass ufo_task = write_model.create_ufo_task(region=None) # Must fail with pytest.raises(UfoModelError): write_model.create_ufo_task(region=[0, 1, 1]) assert ufo_task.props.filename == 'foo.tif' class TestUfoMemoryOutModel: def test_init(self, memory_out_model): assert memory_out_model.num_ports['input'] == 1 assert memory_out_model.num_ports['output'] == 1 def test_expects_multiple_inputs(self, memory_out_model): memory_out_model['number'] = '{region}' assert memory_out_model.expects_multiple_inputs memory_out_model['number'] = '1' assert not memory_out_model.expects_multiple_inputs def test_make_properties(self, memory_out_model): prop_names = {'width', 'height', 'depth', 'number'} assert prop_names == memory_out_model.make_properties().keys() def test_out_data(self, monkeypatch, memory_out_model): def slot(port_index): self.num_called += 1 self.data = memory_out_model.out_data(port_index) self.num_called = 0 memory_out_model['number'] = 10 shape = (int(memory_out_model['number']), memory_out_model['height'], memory_out_model['width']) memory_out_model.create_ufo_task() batch = memory_out_model._batches[0] memory_out_model.data_updated.connect(slot) batch.data[:] = 3 assert len(memory_out_model._batches) == 1 assert batch.data.shape == shape for i in range(shape[0]): batch._on_processed(None) # Called once per 3D array assert self.num_called == 1 # out_data has been set to the batch ouput np.testing.assert_almost_equal(self.data, 3) # Original data must have been freed assert memory_out_model._batches == [None] memory_out_model.reset_batches() # Multiple inputs def slot(port_index): # Append the first item in the current result self.called.append(memory_out_model.out_data(port_index)[0, 0, 0]) self.called = [] memory_out_model.data_updated.connect(slot) memory_out_model['number'] = '{region}' # Two parallel batches of four regions each for j in range(2): for i in range(4): memory_out_model.create_ufo_task(region=[0, 10, 1]) # Set batch data to its linearized index to make checking easy memory_out_model._batches[4 * j + i].data[:] = 4 * j + i # Out of order processing for batch_id in np.array([2, 0, 1, 3], dtype=np.int) + (4 * j): for e in range(10): memory_out_model._batches[batch_id]._on_processed(None) # All regions in the current paralell batch must have been processed assert memory_out_model._waiting_list == [] # Result must be in order np.testing.assert_almost_equal(self.called, np.arange(8)) # Original data must have been freed assert memory_out_model._batches == [None] * 8 def test_reset_batches(self, memory_out_model): memory_out_model.reset_batches() assert memory_out_model._batches == [] assert memory_out_model._waiting_list == [] assert memory_out_model._expecting_id == 0 assert memory_out_model._current_data is None def test_setup_ufo_task(self, memory_out_model): memory_out_model['number'] = '{region}' # Must pass memory_out_model.create_ufo_task(region=[0, 100, 1]) memory_out_model.create_ufo_task(region=[100, 200, 1]) assert len(memory_out_model._batches) == 2 # Must fail with pytest.raises(UfoModelError): memory_out_model.create_ufo_task(region=None) memory_out_model.reset_batches() memory_out_model['number'] = '100' # Must pass memory_out_model.create_ufo_task(region=None) # Must fail with pytest.raises(UfoModelError): memory_out_model.create_ufo_task(region=[0, 100, 1]) assert len(memory_out_model._batches) == 1 class TestImageViewerModel: def test_init(self, image_viewer_model): assert image_viewer_model.num_ports['input'] == 1 assert image_viewer_model.num_ports['output'] == 0 def test_double_clicked(self, qtbot, image_viewer_model): image_viewer_model.double_clicked(None) # No images, no pop up assert image_viewer_model._widget._pg_window is None image_viewer_model._widget.images = np.arange(1000).reshape(10, 10, 10) image_viewer_model.double_clicked(None) assert image_viewer_model._widget._pg_window.isVisible() qtbot.addWidget(image_viewer_model._widget._pg_window) # User closes, must re-open image_viewer_model._widget._pg_window.close() image_viewer_model.double_clicked(None) assert image_viewer_model._widget._pg_window.isVisible() def test_set_in_data(self, image_viewer_model): images = np.arange(1000).reshape(10, 10, 10) image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == images.shape image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == (20,) + images.shape[1:] # Images cannot be appended after reset is called, they must be set image_viewer_model.reset_batches() image_viewer_model.set_in_data(images, None) assert image_viewer_model._widget.images.shape == images.shape def test_reset_batches(self, image_viewer_model): image_viewer_model.reset_batches() assert image_viewer_model._reset def test_get_ufo_model_classes(): # All classes = list(get_ufo_model_classes()) assert classes # Blacklist assert 'read' not in [cls.name for cls in classes] # Selection assert len(list(get_ufo_model_classes(names=['pad']))) == 1 def test_get_composite_model_classes_from_json(qtbot, composite_model): classes = get_composite_model_classes_from_json(composite_model.save()) # First must be the bottom class, top class comes last assert [cls.name for cls in classes] == ['cpm', 'foobar'] def test_get_composite_model_classes(): # Just make sure this runs and the result is not empty assert get_composite_model_classes() tofu-0.12.0/tofu/tests/test_flow_propertylinksmodels.py000066400000000000000000000472631423713721100234740ustar00rootroot00000000000000import pytest from qtpy.QtCore import QByteArray, QMimeData, QModelIndex from tofu.flow.propertylinksmodels import _get_string_path from tofu.flow.propertylinkswidget import _encode_mime_data from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model def setup_silent(link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = (read.model, 'number') link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_silent(composite.model['Read'], 'number', orig_key[0], orig_key[1]) link_model.add_silent(read_2.model, 'height', orig_key[0], orig_key[1]) # Put to 0 to make sure we are not lucky when checking if the links work composite.model['Read']['number'] = 0 read_2.model['height'] = 0 return orig_key class TestNodeTreeModel: def test_add_node(self, qtbot, node_model, nodes): # Unsupported model type not added node_model.add_node(nodes['image_viewer']) assert node_model.rowCount() == 0 # Supported model type (composite is handled in test_add_node) node_model.add_node(nodes['read']) assert node_model.rowCount() == 1 # Composite node_model.add_node(nodes['cpm']) item = node_model.findItems('cpm')[0] # Model contains composite node assert item.data(role=NODE_ROLE) == nodes['cpm'] # and it's children assert item.child(0).data(role=MODEL_ROLE) == nodes['cpm'].model['Pad'] assert item.child(1).data(role=MODEL_ROLE) == nodes['cpm'].model['Read'] # and their properties assert item.child(0).child(0).text() == sorted(nodes['cpm'].model['Pad'])[0] def test_remove_node(self, qtbot, node_model, nodes): node_model.add_node(nodes['cpm']) assert node_model.rowCount() == 1 node_model.remove_node(nodes['cpm']) assert node_model.rowCount() == 0 def test_set_nodes(self, qtbot, node_model, nodes): names = ['cpm', 'read'] subset = [nodes[key] for key in names] node_model.set_nodes(subset) for (i, key) in enumerate(names): assert node_model.item(i).data(role=NODE_ROLE) == nodes[key] def test_clear(self, qtbot, node_model, nodes): node_model.set_nodes(nodes.values()) assert node_model.rowCount() > 0 assert node_model.columnCount() > 0 node_model.clear() assert node_model.rowCount() == 0 assert node_model.columnCount() == 0 class TestPropertyLinksModel: def test_add_item(self, qtbot, link_model, nodes): read = nodes['read'] composite = nodes['cpm'] composite.model.property_links_model = link_model # Put to 0 to make sure we are not lucky below when checking if the links work composite.model['Read']['number'] = 0 # Items must be added link_model.add_item(read, read.model, 'number', -1, -1) item = link_model.item(0, 0) assert item.data(role=NODE_ROLE) == read assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) item = link_model.item(0, 1) assert item.data(role=NODE_ROLE) == composite assert item.data(role=MODEL_ROLE) == composite.model['Read'] assert item.data(role=PROPERTY_ROLE) == 'number' # Can't add one item twice with pytest.raises(ValueError): link_model.add_item(read, read.model, 'number', -1, -1) # Properties must be linked read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == read.model['number'] # When composite is being added, make sure the slave links are set up link_model.remove_item(link_model.find_items([composite], [NODE_ROLE])[0]) composite.model.edit_in_window() qtbot.addWidget(composite.model._other_view) link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) key = (composite.model._window_nodes['Read'].model, 'number') root_key = (composite.model['Read'], 'number') assert link_model._slaves[root_key] == [key] assert link_model._silent[key] == root_key def test_remove_item(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') # Properties must be connected at first read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert read_2.model['number'] == read.model['number'] link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) assert link_model.item(0, 0) is None assert link_model._silent == {} assert link_model._slaves == {} # Properties must be disconnected after removal read.model['number'] = 0 read.model.property_changed.emit(read.model, 'number', read.model['number']) # read_2 still at the old 100 assert read_2.model['number'] == 100 def test_contains(self, qtbot, link_model, nodes): composite = nodes['cpm'] link_model.add_item(composite, composite.model['Read'], 'number', 0, -1) assert link_model.item(0, 0).text() in link_model assert 'foo' not in link_model def test_clear(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') link_model.clear() assert link_model.rowCount() == 0 assert link_model.columnCount() == 0 assert link_model._silent == {} assert link_model._slaves == {} def test_find_items(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] # Empty model assert link_model.find_items([read.model], [MODEL_ROLE]) == [] link_model.add_item(read, read.model, 'number', -1, -1) # Not inside assert link_model.find_items([read_2.model], [MODEL_ROLE]) == [] # Inside assert (link_model.find_items([read.model], [MODEL_ROLE])[0].data(role=MODEL_ROLE) == read.model) # Model not inside, property not inside assert link_model.find_items((read_2.model, 'height'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model inside, property not inside assert link_model.find_items((read.model, 'height'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model not inside, property inside assert link_model.find_items((read_2.model, 'number'), (MODEL_ROLE, PROPERTY_ROLE)) == [] # Model inside, property inside item = link_model.find_items((read.model, 'number'), (MODEL_ROLE, PROPERTY_ROLE))[0] assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' def test_get_model_links(sef, qtbot, link_model, nodes): populate_link_model(link_model, nodes) assert link_model.get_model_links(nodes['read_3'].model) == {} links = link_model.get_model_links([nodes['read'].model, nodes['read_2'].model, nodes['cpm'].model['Read']]) links = list(links.values()) # Just one row assert len(links) == 1 # Three items in that row assert len(links[0]) == 3 assert [nodes['read'].model.caption, 'number'] in links[0] assert [nodes['read_2'].model.caption, 'height'] in links[0] path = nodes['cpm'].model.get_path_from_model(nodes['cpm'].model['Read']) str_path = [model.caption for model in path] + ['y'] assert str_path in links[0] def test_get_root_model(self, qtbot, link_model, nodes): read = nodes['read'] composite = nodes['cpm'] link_model.add_item(read, read.model, 'number', -1, -1) # Not inside assert link_model.get_root_model(nodes['read_2'].model) is None # Directly inside assert link_model.get_root_model(read.model) == read.model # Indirectly inside via silent link_model.add_silent(composite.model['Read'], 'number', read.model, 'number') assert link_model.get_root_model(composite.model['Read']) == read.model def test_get_model_properties(self, qtbot, link_model, nodes): read = nodes['read'] link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read, read.model, 'height', -1, -1) # Empty assert link_model.get_model_properties(nodes['read_2'].model) == [] # Multiple assert set(link_model.get_model_properties(read.model)) == set(['number', 'height']) def test_add_silent(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) # orig model not inside with pytest.raises(ValueError): link_model.add_silent(composite.model['Read'], 'height', nodes['read_3'].model, 'number') # source property not inside with pytest.raises(ValueError): link_model.add_silent(composite.model['Read'], 'height', read.model, 'height') # Links inside assert len(link_model._slaves[orig_key]) == 2 key = (composite.model['Read'], 'number') assert link_model._silent[key] == orig_key assert key in link_model._slaves[orig_key] key = (read_2.model, 'height') assert link_model._silent[key] == orig_key assert key in link_model._slaves[orig_key] # Properties conected read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == read.model['number'] assert read_2.model['height'] == read.model['number'] def test_remove_silent(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) key = (composite.model['Read'], 'number') link_model.remove_silent(*key) assert key not in link_model._silent # Silent link disconected read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == 0 assert read_2.model['height'] == read.model['number'] # No more slaves, remove the original key as well key = (nodes['read_2'].model, 'height') link_model.remove_silent(*key) assert orig_key not in link_model._slaves def test_replace_item(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] composite = nodes['cpm'] orig_key = setup_silent(link_model, nodes) replacer = nodes['read_3'] item = link_model.find_items(orig_key, (MODEL_ROLE, PROPERTY_ROLE))[0] (row, column) = item.row(), item.column() link_model.replace_item(replacer, replacer.model, orig_key[0]) new_item = link_model.item(row, column) assert new_item.data(role=MODEL_ROLE) == replacer.model # Silent links re-connected # This must have no effect on silent models read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) assert composite.model['Read']['number'] == 0 assert read_2.model['height'] == 0 # This must change silent models' properties replacer.model['number'] = 100 replacer.model.property_changed.emit(replacer.model, 'number', replacer.model['number']) assert composite.model['Read']['number'] == replacer.model['number'] assert read_2.model['height'] == replacer.model['number'] def test_on_node_rows_about_to_be_removed(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] node_model.add_node(read) node_model.add_node(read_2) node_model.add_node(read_3) link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) link_model.add_item(read_3, read_3.model, 'number', -1, -1) # Remove one node_model.removeRow(0) assert link_model.find_items([read.model], [MODEL_ROLE]) == [] # Remove all node_model.clear() assert link_model.find_items([read_2.model], [MODEL_ROLE]) == [] assert link_model.find_items([read_3.model], [MODEL_ROLE]) == [] def test_canDropMimeData(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] node_model.add_node(read) node_model.add_node(read_2) # Incompatible QMimeData data = QMimeData() data.setData('application/x-foobar', QByteArray()) assert not link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) # No parent index = get_index_from_treemodel(node_model, 0, 'number') data = _encode_mime_data(index) assert link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) link_model.add_item(read, read.model, 'number', -1, -1) assert not link_model.canDropMimeData(data, None, -1, -1, QModelIndex()) # On parent # Compatible property type index = get_index_from_treemodel(node_model, 1, 'number') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) assert link_model.canDropMimeData(data, None, 0, 0, parent) # Incompatible property type index = get_index_from_treemodel(node_model, 1, 'path') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) assert not link_model.canDropMimeData(data, None, 0, 0, parent) def test_dropMimeData(self, qtbot, link_model, node_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] node_model.add_node(read) node_model.add_node(read_2) # No parent index = get_index_from_treemodel(node_model, 0, 'number') data = _encode_mime_data(index) link_model.dropMimeData(data, None, -1, -1, QModelIndex()) item = link_model.item(0, 0) assert item.data(role=NODE_ROLE) == read assert item.data(role=MODEL_ROLE) == read.model assert item.data(role=PROPERTY_ROLE) == 'number' # On parent index = get_index_from_treemodel(node_model, 1, 'number') data = _encode_mime_data(index) parent = link_model.indexFromItem(link_model.item(0, 0)) link_model.dropMimeData(data, None, -1, -1, parent) item = link_model.item(0, 1) assert item.data(role=NODE_ROLE) == read_2 assert item.data(role=MODEL_ROLE) == read_2.model assert item.data(role=PROPERTY_ROLE) == 'number' def test_save(self, qtbot, link_model, nodes): records = populate_link_model(link_model, nodes) for (i, (node_id, str_path)) in enumerate(link_model.save()[0]): assert node_id == records[i][0].id path = _get_string_path(records[i][0], records[i][1], records[i][2]) assert str_path == path def test_restore(self, qtbot, link_model, nodes): records = populate_link_model(link_model, nodes) state = link_model.save() link_model.clear() # Add new item read_3 = nodes['read_3'] link_model.add_item(read_3, read_3.model, 'number', -1, -1) link_model.restore(state, {node.id: node for node in nodes.values()}) assert link_model.columnCount() == 3 for column in range(link_model.columnCount()): item = link_model.item(0, column) assert item.data(role=NODE_ROLE) == records[column][0] assert item.data(role=MODEL_ROLE) == records[column][1] assert item.data(role=PROPERTY_ROLE) == records[column][2] # Restore must clear whatever is inside assert link_model.find_items([read_3.model], [MODEL_ROLE]) == [] def test_compact(self, qtbot, link_model, nodes): read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] read_4 = nodes['read_4'] def populate(): link_model.add_item(read, read.model, 'number', 0, 0) link_model.add_item(read_2, read_2.model, 'number', 0, 1) link_model.add_item(read_3, read_3.model, 'number', 1, 0) link_model.add_item(read_4, read_4.model, 'number', 1, 1) def check(row_count, column_count): assert link_model.rowCount() == row_count assert link_model.columnCount() == column_count populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 1))) link_model.compact() check(2, 2) link_model.clear() # Shift item to the left to an unused cell populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.compact() assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2 check(2, 2) # Nothing in the row, remove it link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.compact() check(1, 2) # Remove column 0 and shift 1st column to the left link_model.clear() populate() link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0))) link_model.remove_item(link_model.indexFromItem(link_model.item(1, 0))) link_model.compact() assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2 assert link_model.item(1, 0).data(role=NODE_ROLE) == read_4 check(2, 1) def test_on_property_changed(self, qtbot, link_model, nodes): composite = nodes['cpm'] read = nodes['read'] read_2 = nodes['read_2'] read_3 = nodes['read_3'] read_4 = nodes['read_4'] # Read 2->height and cpm->Read->number are silent dependend on Read->number setup_silent(link_model, nodes) # Put every linked property to 0 to make sure we are not lucky when checking if the links # work read_3.model['number'] = 0 read_4.model['number'] = 0 composite.model['Pad']['width'] = 0 link_model.add_item(read_3, read_3.model, 'number', 0, 1) link_model.add_item(read_4, read_4.model, 'number', 1, 0) link_model.add_item(composite, composite.model['Pad'], 'width', 1, 1) read.model['number'] = 100 read.model.property_changed.emit(read.model, 'number', read.model['number']) # Row 0 # Direct link assert read_3.model['number'] == read.model['number'] # Silent links assert read_2.model['height'] == read.model['number'] assert composite.model['Read']['number'] == read.model['number'] # Row 1 read_4.model['number'] = 100 read_4.model.property_changed.emit(read_4.model, 'number', read_4.model['number']) assert composite.model['Pad']['width'] == read_4.model['number'] tofu-0.12.0/tofu/tests/test_flow_propertylinkswidget.py000066400000000000000000000032251423713721100234620ustar00rootroot00000000000000import pytest from PyQt5.QtCore import Qt, QItemSelectionModel from tofu.flow.propertylinkswidget import NodesView, PropertyLinks, PropertyLinksView from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model @pytest.fixture(scope='function') def node_view(node_model): view = NodesView() view.setHeaderHidden(True) view.setAlternatingRowColors(True) view.setDragEnabled(True) view.setAcceptDrops(False) view.setModel(node_model) return view @pytest.fixture(scope='function') def link_view(): return PropertyLinksView() @pytest.fixture(scope='function') def link_widget(node_model): return PropertyLinks() def test_property_links_view_delete_key(qtbot, link_model, link_view, nodes): qtbot.addWidget(link_view) link_view.setModel(link_model) populate_link_model(link_model, nodes) link_view.selectColumn(0) qtbot.keyPress(link_view, Qt.Key_Delete) assert link_model.columnCount() == 2 def test_node_view_get_drag_index(qtbot, node_view, nodes): node_model = node_view.model() read = nodes['read'] node_model.add_node(read) sm = node_view.selectionModel() # Nothing selected assert node_view.get_drag_index() is None # Node selection must yield nothing index = node_model.indexFromItem(node_model.item(0, 0)) sm.select(index, QItemSelectionModel.Select) assert node_view.get_drag_index() is None sm.clear() # Property selection must yield an index which can be dragged index = get_index_from_treemodel(node_model, 0, 'number') sm.select(index, QItemSelectionModel.Select) assert node_view.get_drag_index() is not None sm.clear() tofu-0.12.0/tofu/tests/test_flow_runslider.py000066400000000000000000000147101423713721100213410ustar00rootroot00000000000000import pytest from tofu.flow.runslider import RunSlider, RunSliderError @pytest.fixture(scope='function') def runslider(qtbot, scene): slider = RunSlider() node = scene.create_node(scene.registry.create('filter')) slider.setup(node.model._view._properties['cutoff'].view_item) qtbot.addWidget(slider) return slider class TestRunSlider: def test_setup(self, qtbot, runslider): assert not runslider.setup(runslider.view_item) bottom = runslider.view_item.widget.validator().bottom() top = runslider.view_item.widget.validator().top() assert runslider.type == float assert runslider.real_minimum == bottom assert runslider.real_maximum == top assert float(runslider.min_edit.text()) == bottom assert float(runslider.max_edit.text()) == top assert float(runslider.current_edit.text()) == runslider.view_item.get() assert runslider.slider.value() / 100 + runslider.real_minimum == runslider.view_item.get() assert runslider.isEnabled() def test_reset(self, qtbot, runslider): runslider.reset() assert runslider.view_item is None assert runslider.type is None assert runslider.real_minimum == 0 assert runslider.real_maximum == 100 assert runslider.real_span == 100 assert runslider.min_edit.text() == '' assert runslider.max_edit.text() == '' assert runslider.current_edit.text() == '' assert not runslider.isEnabled() def test_empty(self, qtbot, runslider): runslider.reset() runslider.on_min_edit_editing_finished() runslider.on_max_edit_editing_finished() runslider.on_current_edit_editing_finished() def test_min_edit_changed(self, qtbot, runslider): top = runslider.view_item.widget.validator().top() with pytest.raises(RunSliderError): runslider.min_edit.setText('asdf') runslider.on_min_edit_editing_finished() with pytest.raises(RunSliderError): runslider.min_edit.setText(str(top + 1)) runslider.on_min_edit_editing_finished() # Current value lower than new minimum, must be updated value = runslider.get_real_value() + 0.1 runslider.min_edit.setText(str(value)) runslider.on_min_edit_editing_finished() assert value == runslider.get_real_value() def test_max_edit_changed(self, qtbot, runslider): bottom = runslider.view_item.widget.validator().bottom() with pytest.raises(RunSliderError): runslider.max_edit.setText('asdf') runslider.on_max_edit_editing_finished() with pytest.raises(RunSliderError): runslider.max_edit.setText(str(bottom - 1)) runslider.on_max_edit_editing_finished() # Current value greater than new maximum, must be updated value = runslider.get_real_value() - 0.1 runslider.max_edit.setText(str(value)) runslider.on_max_edit_editing_finished() assert value == runslider.get_real_value() def test_current_edit_changed(self, qtbot, runslider): self.value_changed_value = None def on_value_changed(value): self.value_changed_value = value # Nothing changed, no update triggered runslider.on_current_edit_editing_finished() assert self.value_changed_value is None runslider.value_changed.connect(on_value_changed) current = runslider.get_real_value() + 0.1 runslider.current_edit.setText(str(current)) runslider.on_current_edit_editing_finished() assert runslider.get_real_value() == current runslider.current_edit.setText('asf') with pytest.raises(RunSliderError): runslider.on_current_edit_editing_finished() def test_int(self, qtbot, runslider, scene): node = scene.create_node(scene.registry.create('read')) runslider.setup(node.model._view._properties['y'].view_item) assert runslider.type == int runslider.min_edit.setText('1') runslider.on_min_edit_editing_finished() runslider.max_edit.setText('10') runslider.on_max_edit_editing_finished() runslider.slider.setValue(50) assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 5 runslider.current_edit.setText('7') runslider.on_current_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 7 # Maximum smaller than current -> update current runslider.max_edit.setText('5') runslider.on_max_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 5 # Minimum greater than current -> update current runslider.max_edit.setText('10') runslider.on_max_edit_editing_finished() runslider.min_edit.setText('8') runslider.on_min_edit_editing_finished() assert type(runslider.get_real_value()) == int assert runslider.get_real_value() == 8 def test_range(self, qtbot, scene): runslider = RunSlider() qtbot.addWidget(runslider) node = scene.create_node(scene.registry.create('general_backproject')) node.model['center-position-x'] = [1, 2, 3] assert not runslider.setup(node.model._view._properties['center-position-x'].view_item) assert runslider.view_item is None node.model['center-position-x'] = [1] assert runslider.setup(node.model._view._properties['center-position-x'].view_item) assert runslider.view_item == node.model._view._properties['center-position-x'].view_item assert type(runslider.get_real_value()) == float assert runslider.get_real_value() == 1 runslider.current_edit.setText('1.1') runslider.on_current_edit_editing_finished() assert node.model['center-position-x'] == [runslider.get_real_value()] def test_links(self, qtbot, link_model, nodes): runslider = RunSlider() qtbot.addWidget(runslider) read = nodes['read'] read_2 = nodes['read_2'] runslider.setup(read.model._view._properties['number'].view_item) link_model.add_item(read, read.model, 'number', -1, -1) link_model.add_item(read_2, read_2.model, 'number', 0, -1) runslider.current_edit.setText('123') runslider.on_current_edit_editing_finished() assert read_2.model['number'] == 123 tofu-0.12.0/tofu/tests/test_flow_scene.py000066400000000000000000000414061423713721100204310ustar00rootroot00000000000000import pytest from PyQt5.QtCore import QModelIndex from PyQt5.QtWidgets import QInputDialog from qtpynodeeditor import FlowView from tofu.flow.models import BaseCompositeModel, UfoModelError, UfoReadModel from tofu.flow.util import FlowError, MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE from tofu.tests.flow_util import add_nodes_to_scene class TestScene: def test_create_node(self, qtbot, scene): def check_node(node, gt_caption): # Node must be in the scene assert node in scene.nodes.values() # Caption must be unique assert node.model.caption == gt_caption # Node must be in the nodes model item = scene.node_model.findItems(node.model.caption)[0] assert item.data(role=MODEL_ROLE) == node.model nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) for (node, gt_caption) in zip(nodes, ['Read', 'Read 2']): check_node(node, gt_caption) # Property links must be set up by composites def check_link(model, prop_name): assert scene.property_links_model.find_items((model, prop_name), (MODEL_ROLE, PROPERTY_ROLE)) scene.clear_scene() node = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect'])[0] for link in node.model._links: model_name, prop_name = link[0] other_name, other_prop_name = link[1] model = node.model[model_name] other = node.model[other_name] model[prop_name] = 0 qtbot.addWidget(model.embedded_widget()) qtbot.addWidget(other.embedded_widget()) qtbot.keyClick(model._view._properties[prop_name].view_item.widget, '1') # Other model's property has to be updated if the property links have been set up # correctly assert node.model[other_name][other_prop_name] == node.model[model_name][prop_name] def test_setstate(self, qtbot, scene): # Make sure there are some links by adding FFC nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) state = scene.__getstate__() scene.clear_scene() scene.__setstate__(state) assert scene.__getstate__() == state def test_getstate(self, qtbot, scene): # Make sure there are some links by adding FFC nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) state = scene.__getstate__() # Nodes ids = [record['id'] for record in state['nodes']] assert nodes[0].id in ids assert nodes[1].id in ids # Connections assert len(state['connections']) == 1 conn = state['connections'][0] assert conn['in_id'] == nodes[1].id assert conn['out_id'] == nodes[0].id # Property links assert state['property-links'] == scene.property_links_model.save() def test_restore_node(self, qtbot, monkeypatch, scene): add_nodes_to_scene(scene) old_node = list(scene.nodes.values())[0] state = old_node.__getstate__() scene.remove_node(old_node) new_node = scene.restore_node(state) # Don't test the nodes themselves because the models won't match assert old_node.id == new_node.id # num-inputs monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True)) node = add_nodes_to_scene(scene, model_names=['retrieve_phase'])[0] state = node.__getstate__() scene.remove_node(node) new_node = scene.restore_node(state) assert new_node.model.num_ports['input'] == 2 def test_remove_node(self, monkeypatch, qtbot, scene, nodes): def cleanup(): self.cleanup_called = True node = add_nodes_to_scene(scene)[0] self.cleanup_called = False node.model.cleanup = cleanup scene.property_links_model.add_item(node, node.model, node.model.properties[0], 0, 0, QModelIndex()) scene.remove_node(list(scene.nodes.values())[0]) # Scene, node model and property links model must be empty assert len(scene.nodes) == 0 assert scene.node_model.rowCount() == 0 assert scene.property_links_model.rowCount() == 0 assert self.cleanup_called # Composite removal monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=['pad', 'crop']) for node in nodes: node.graphics_object.setSelected(True) node = scene.create_composite() state = node.__getstate__() scene.remove_node(node) # _composite_nodes must be updated assert scene._composite_nodes == {} # Simulate non-interactive composite creation, i.e. not combining existing nodes into a # composite node. When removing such node, _composite_nodes must not raise a KeyError node = scene.restore_node(state) scene.remove_node(node) def test_is_selected_one_composite(self, qtbot, scene, monkeypatch): # Circumvent the input dialog monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) for node in nodes: node.graphics_object.setSelected(True) # Simple nodes assert not scene.is_selected_one_composite() node = scene.create_composite() # Composite assert scene.is_selected_one_composite() node.graphics_object.setSelected(False) # Nothing selected assert not scene.is_selected_one_composite() # Composite and other selected add_nodes_to_scene(scene, ['null']) for node in scene.nodes.values(): node.graphics_object.setSelected(True) assert not scene.is_selected_one_composite() def test_skip_nodes(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) read.graphics_object.setSelected(True) # Only fully connected nodes can be disabled with pytest.raises(FlowError): scene.skip_nodes() read.graphics_object.setSelected(False) null.graphics_object.setSelected(True) with pytest.raises(FlowError): scene.skip_nodes() null.graphics_object.setSelected(False) pad.graphics_object.setSelected(True) scene.skip_nodes() assert pad.model.skip scene.skip_nodes() assert not pad.model.skip # Deprecation warning coming from imageio @pytest.mark.filterwarnings('ignore::DeprecationWarning') def test_auto_fill(self, qtbot, scene): add_nodes_to_scene(scene) with pytest.raises(UfoModelError): scene.auto_fill() def test_copy_node(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'null']) scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) for node in nodes: node.graphics_object.setSelected(True) scene.copy_nodes() assert len(scene.nodes) == 4 # Choose the newly created connections if scene.connections[0].valid_ports['input'].node in nodes: ports = scene.connections[1].valid_ports else: ports = scene.connections[0].valid_ports # The fact that the connections are there means the nodes are there as well, so we don't # need to test that assert ports['input'].node.model.name == 'null' assert ports['output'].node.model.name == 'read' def test_create_composite(self, monkeypatch, qtbot, scene): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) plm = scene.property_links_model nodes = add_nodes_to_scene(scene, model_names=['read', 'read']) plm.add_item(nodes[0], nodes[0].model, nodes[0].model.properties[0], -1, -1, QModelIndex()) for (i, node) in enumerate(nodes): node.graphics_object.setSelected(True) node = scene.create_composite() assert node.model._links == [[[nodes[0].model.caption, nodes[0].model.properties[0]]]] with pytest.raises(FlowError): # Can't create a composite with the same name scene.create_composite() assert len(scene.nodes) == 1 assert list(scene.nodes.values())[0] == node assert isinstance(node.model, BaseCompositeModel) assert nodes[0] not in scene.nodes.values() assert nodes[1] not in scene.nodes.values() # Property links model assert plm.item(0, 0).data(role=NODE_ROLE) == node # Simulate non-interactive composite creation, i.e. not combining existing nodes into a # composite node. In this case it can't be possible to create a new composite node with the # same name as has already been registered. state = node.__getstate__() scene.remove_node(node) node = scene.restore_node(state) node.graphics_object.setSelected(True) with pytest.raises(FlowError): scene.create_composite() # Add outer composite with a composite and another simple model inside and set the property # links between the inner composite and inner simple. They must be present in the newly # craeted outer composite. average = add_nodes_to_scene(scene, model_names=['average'])[0] average.graphics_object.setSelected(True) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True)) plm.add_item(node, node.model['Read'], 'number', 0, 0) plm.add_item(node, node.model['Read 2'], 'number', 0, 1) plm.add_item(average, average.model, 'number', 0, 2) outer = scene.create_composite() assert len(outer.model._links) == 1 row = outer.model._links[0] assert [node.model.caption, node.model['Read'].caption, 'number'] in row assert [node.model.caption, node.model['Read 2'].caption, 'number'] in row assert [average.model.caption, 'number'] in row assert [node.model.caption, node.model['Read'].caption, 'height'] not in row def test_on_node_double_clicked(self, qtbot, scene, monkeypatch): def double_clicked(*args): self.did_click = True self.did_click = False monkeypatch.setattr(UfoReadModel, "double_clicked", double_clicked) node = add_nodes_to_scene(scene)[0] # We need a view for double clicks _ = FlowView(scene) scene.on_node_double_clicked(node) assert self.did_click def test_expand_composite(self, qtbot, scene, monkeypatch): monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) plm = scene.property_links_model nodes = add_nodes_to_scene(scene, model_names=['read', 'null']) name_to_caption = {'read': 'Read', 'null': 'Null'} for node in nodes: node.graphics_object.setSelected(True) node = scene.create_composite() path = node.model.get_leaf_paths()[0] plm.add_item(node, path[-1], path[-1].properties[0], -1, -1, QModelIndex()) scene.expand_composite(node) assert plm.item(0, 0).data(role=MODEL_ROLE).name == path[-1].name assert plm.item(0, 0).data(role=NODE_ROLE) in [node for node in scene.selected_nodes()] # Captions are the same for node in scene.nodes.values(): assert node.model.caption == name_to_caption[node.model.name] # New caption if there is a node with the original one # Selection stays, just re-use the expanded nodes monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) node = scene.create_composite() other_read_node = add_nodes_to_scene(scene, model_names=['read'])[0] scene.expand_composite(node) for node in scene.nodes.values(): if node.model.name == 'read': if node == other_read_node: assert node.model.caption == 'Read' else: assert node.model.caption == 'Read 2' def test_is_fully_connected(self, qtbot, scene): nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) assert scene.is_fully_connected() scene.remove_node(read) assert not scene.is_fully_connected() def test_get_simple_node_graphs(self, qtbot, scene, monkeypatch): def connect(read, pad, crop, null): scene.create_connection(read['output'][0], pad['input'][0]) scene.create_connection(pad['output'][0], crop['input'][0]) scene.create_connection(crop['output'][0], null['input'][0]) monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True)) nodes = add_nodes_to_scene(scene, model_names=2 * ['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes[:4] read_2, pad_2, crop_2, null_2 = nodes[4:] connect(read, pad, crop, null) connect(read_2, pad_2, crop_2, null_2) connections = [('Read', 'Pad'), ('Pad', 'Crop'), ('Crop', 'Null'), ('Read 2', 'Pad 2'), ('Pad 2', 'Crop 2'), ('Crop 2', 'Null 2')] graphs = scene.get_simple_node_graphs() assert len(graphs) == 2 num_visited = 0 for graph in graphs: for (src, dst, index) in graph.edges: assert (src.caption, dst.caption) in connections num_visited += 1 assert num_visited == len(connections) # Create first composite for node in nodes: if node.model.name in ['pad', 'crop']: node.graphics_object.setSelected(True) scene.create_composite() # Create a second composite which will cause the scene to have multiple edges between two # nodes (the first composite's outputs and second's inputs) scene.clearSelection() monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True)) null.graphics_object.setSelected(True) null_2.graphics_object.setSelected(True) scene.create_composite() # Composite must not affect simple graphs, especially the multiple edges cannot be present # anymore graphs = scene.get_simple_node_graphs() assert len(graphs) == 2 num_visited = 0 for graph in graphs: for (src, dst, index) in graph.edges: assert (src.caption, dst.caption) in connections num_visited += 1 assert num_visited == len(connections) add_nodes_to_scene(scene) assert len(scene.get_simple_node_graphs()) == 3 # Test disabling nodes scene.clear_scene() nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null']) read, pad, crop, null = nodes connect(read, pad, crop, null) # Disable padding, the generated flow must be read -> crop -> null pad.graphics_object.setSelected(True) scene.skip_nodes() graph = scene.get_simple_node_graphs()[0] assert len(graph.edges) == 2 edges = list(graph.edges) src, dst = edges[0][:-1] assert dst == crop.model src, dst = edges[1][:-1] assert src == crop.model assert dst == null.model def test_set_enabled(self, qtbot, scene): def check(enabled): assert scene.allow_node_creation == enabled assert scene.allow_node_deletion == enabled for node in scene.nodes.values(): assert node._graphics_obj.isEnabled() == enabled for conn in scene.connections: assert conn._graphics_object.isEnabled() == enabled nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average']) nodes[0].graphics_object.setSelected(True) # Create a connection scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0]) scene.set_enabled(False) check(False) scene.set_enabled(True) check(True) assert nodes[0].graphics_object.isSelected() tofu-0.12.0/tofu/tests/test_flow_util.py000066400000000000000000000036671423713721100203200ustar00rootroot00000000000000import pytest from PyQt5.QtWidgets import QInputDialog from tofu.flow.util import CompositeConnection, get_config_key, saved_kwargs from tofu.flow.main import get_filled_registry def test_get_config_key(): # Existing key assert 'z' in get_config_key('models', 'general-backproject', 'hidden-properties') # Non-existent key assert get_config_key('foobarbaz') is None assert get_config_key('foobarbaz', default=1) == 1 def test_saved_kwargs(qtbot, monkeypatch, scene): registry = get_filled_registry() name = 'retrieve_phase' # No num-inputs info monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True)) state = {'name': name} model = registry.create(name) assert model.num_ports['input'] == 2 # num-inputs specified state = {'name': name, 'num-inputs': 3} with saved_kwargs(registry, state): model = registry.create(name) assert model.num_ports['input'] == 3 assert 'num_inputs' not in registry.registered_model_creators()[state['name']][1] class TestCompositeConnection: def test_init(self): # Identical source and tartet -> exception with pytest.raises(ValueError): CompositeConnection('a', 0, 'a', 0) # OK, must pass CompositeConnection('a', 0, 'b', 0) def test_contains(self): conn = CompositeConnection('a', 0, 'b', 0) assert conn.contains('a', 'output', 0) assert not conn.contains('a', 'output', 1) assert not conn.contains('a', 'input', 0) assert not conn.contains('a', 'input', 1) assert conn.contains('b', 'input', 0) assert not conn.contains('b', 'input', 1) assert not conn.contains('b', 'output', 0) assert not conn.contains('b', 'output', 1) assert not conn.contains('foo', 'input', 14) def test_save(self): conn = CompositeConnection('a', 0, 'b', 0) assert conn.save() == ['a', 0, 'b', 0] tofu-0.12.0/tofu/tests/test_flow_viewer.py000066400000000000000000000272341423713721100206400ustar00rootroot00000000000000import pytest import numpy as np from PyQt5.QtGui import QValidator from tofu.flow.viewer import ImageLabel, ImageViewingError, ScreenImage, ImageViewer @pytest.fixture(scope='function') def screen_image(): image = np.arange(256, dtype=np.float32).reshape(16, 16) return ScreenImage(image=image) @pytest.fixture(scope='function') def viewer(qtbot): viewer = ImageViewer() viewer.images = np.ones((10, 16, 16)) viewer.popup() qtbot.addWidget(viewer._pg_window) return viewer class TestScreenImage: def test_image_setter(self): screen_image = ScreenImage() assert screen_image.image is None screen_image.image = np.random.normal(size=(8, 8)) assert screen_image.minimum is not None assert screen_image.maximum is not None assert screen_image.black_point is not None assert screen_image.white_point is not None def test_black_point_setter(self, screen_image): screen_image.black_point = 100 assert screen_image.black_point == 100 screen_image.white_point = 150 # Black point cannot be greater than white point with pytest.raises(ImageViewingError): screen_image.black_point = 200 def test_white_point_setter(self, screen_image): screen_image.white_point = 100 assert screen_image.white_point == 100 screen_image.black_point = 50 # White point cannot be smaller than black point with pytest.raises(ImageViewingError): screen_image.white_point = 0 def test_reset(self, screen_image): screen_image.reset() # We can test with ==, the data types are the same so the extrema must be exactly the same assert screen_image.minimum == 0 assert screen_image.maximum == 255 assert screen_image.black_point == 0 assert screen_image.white_point == 255 # Going out of the original gray value range must not cause exception on reset screen_image.black_point = -100 screen_image.white_point = -50 screen_image.reset() def test_auto_levels(self, screen_image): screen_image.auto_levels() # nonsense values must pass as well screen_image.auto_levels(percentile=200.0) screen_image.auto_levels(percentile=-200.0) def test_set_black_point_normalized(self, screen_image): screen_image.set_black_point_normalized(100) assert screen_image.black_point == 100 screen_image.set_white_point_normalized(150) # Black point cannot be greater than white point with pytest.raises(ImageViewingError): screen_image.set_black_point_normalized(200) def test_set_white_point_normalized(self, screen_image): screen_image.set_white_point_normalized(100) assert screen_image.white_point == 100 screen_image.set_black_point_normalized(50) # White point cannot be smaller than black point with pytest.raises(ImageViewingError): screen_image.set_white_point_normalized(0) def test_convert_normalized_value_to_native(self, screen_image): assert screen_image.convert_normalized_value_to_native(128) == 128. with pytest.raises(ImageViewingError): screen_image.convert_normalized_value_to_native(-500) with pytest.raises(ImageViewingError): screen_image.convert_normalized_value_to_native(500) def test_convert_native_value_to_normalized(self, screen_image): assert screen_image.convert_native_value_to_normalized(128) == 128. with pytest.raises(ImageViewingError): screen_image.convert_native_value_to_normalized(-500) with pytest.raises(ImageViewingError): screen_image.convert_native_value_to_normalized(500) # One gray value must not cause division by zero erro screen_image.image = np.ones((4, 4)) screen_image.reset() screen_image.convert_native_value_to_normalized(1) def test_get_pixmap(self, qtbot, screen_image): # Empty image must raise an exception with pytest.raises(ImageViewingError): ScreenImage().get_pixmap() # Downsampling pixmap = screen_image.get_pixmap() assert (pixmap.height(), pixmap.width()) == screen_image.image.shape pixmap = screen_image.get_pixmap(downsampling=2) assert (pixmap.height(), pixmap.width()) == tuple(dim // 2 for dim in screen_image.image.shape) # One gray value must not cause division by zero erro screen_image.image = np.ones((4, 4)) screen_image.reset() screen_image.get_pixmap() class TestImageLabel: def test_updateImage(self, qtbot, screen_image): label = ImageLabel() # Empty image must pass label.updateImage() label.screen_image = screen_image label.updateImage() assert label.pixmap() is not None def test_resizeEvent(self, qtbot, screen_image): label = ImageLabel(screen_image) label.updateImage() old_size = label.pixmap().size() # ensure the label will get the resize event label.show() label.resize(8, 8) new_size = label.pixmap().size() assert new_size != old_size class TestImageViewer: def test_images_setter(self, qtbot): viewer = ImageViewer() viewer.images = np.zeros((16, 16)) assert viewer.images.ndim == 3 assert viewer.slider.isHidden() assert float(viewer.min_slider_edit.text()) == 0 assert float(viewer.max_slider_edit.text()) == 0 viewer.images = np.ones((5, 16, 16)) assert viewer.images.ndim == 3 assert not viewer.slider.isHidden() assert viewer.slider.minimum() == 0 assert viewer.slider.maximum() == viewer.images.shape[0] - 1 assert float(viewer.min_slider_edit.text()) == 1 assert float(viewer.max_slider_edit.text()) == 1 # Test viewer and popup window equality viewer.popup() qtbot.addWidget(viewer._pg_window) np.testing.assert_almost_equal(viewer.images, 1) np.testing.assert_almost_equal(viewer._pg_window.image, 1) # 3D viewer.images = np.ones((5, 16, 16)) * 5 np.testing.assert_almost_equal(viewer.images, 5) np.testing.assert_almost_equal(viewer._pg_window.image, 5) # 2D viewer.images = np.ones((16, 16)) * 3 np.testing.assert_almost_equal(viewer.images, 3) np.testing.assert_almost_equal(viewer._pg_window.image, 3) # validators viewer.images = 10 + np.arange(200 * 8 ** 2).reshape(200, 8, 8) validator = viewer.slider_edit.validator() assert validator.validate('199', 0)[0] == QValidator.Acceptable assert validator.validate('2000', 0)[0] == QValidator.Invalid assert viewer.min_slider_edit.validator().bottom() == viewer.images[0].min() assert viewer.min_slider_edit.validator().top() == viewer.images[0].max() viewer._pg_window.close() def test_append(self, qtbot): viewer = ImageViewer() # Append to empty viewer.append(np.zeros((4, 4))) assert viewer.images.ndim == 3 assert viewer.images.shape == (1, 4, 4) # Append 2D viewer.append(np.zeros((4, 4))) assert viewer.images.shape == (2, 4, 4) # Append 3D viewer.append(np.zeros((3, 4, 4))) assert viewer.images.shape == (5, 4, 4) # Append wrong shape with pytest.raises(ImageViewingError): viewer.append(np.zeros((3, 2, 2))) def test_set_enabled_adjustments(self, qtbot): viewer = ImageViewer() def assert_all(value): viewer.set_enabled_adjustments(value) assert viewer.slider.isEnabled() == value assert viewer.slider_edit.isEnabled() == value assert viewer.min_slider.isEnabled() == value assert viewer.min_slider_edit.isEnabled() == value assert viewer.max_slider.isEnabled() == value assert viewer.max_slider_edit.isEnabled() == value assert_all(True) assert_all(False) def test_reset_clim(self, viewer): image = np.arange(16 ** 2).reshape(16, 16) viewer.images = image viewer.append(image * 2) viewer.slider.setValue(1) viewer.reset_clim(auto=False) si = viewer.screen_image min_converted = si.convert_native_value_to_normalized(si.black_point) max_converted = si.convert_native_value_to_normalized(si.white_point) assert viewer.screen_image.maximum == pytest.approx(510) assert viewer.min_slider.value() == pytest.approx(min_converted) assert viewer.max_slider.value() == pytest.approx(max_converted) assert float(viewer.min_slider_edit.text()) == pytest.approx(si.black_point) assert float(viewer.max_slider_edit.text()) == pytest.approx(si.white_point) # Pop up window must be updated assert viewer._pg_window.getLevels() == pytest.approx((si.black_point, si.white_point)) viewer._pg_window.close() def test_on_slider_value_changed(self, viewer): viewer.slider.setValue(5) assert viewer._pg_window.currentIndex == 5 assert viewer.slider_edit.text() == '5' viewer._pg_window.close() def test_on_slider_edit_return_pressed(self, viewer): viewer.slider_edit.setText('5') viewer.slider_edit.returnPressed.emit() assert viewer.slider.value() == 5 assert viewer._pg_window.currentIndex == 5 viewer._pg_window.close() def test_on_min_slider_edit_return_pressed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.min_slider_edit.setText('100') viewer.min_slider_edit.returnPressed.emit() assert viewer.screen_image.black_point == pytest.approx(100) assert viewer.min_slider.value() == pytest.approx(100) assert viewer._pg_window.getLevels()[0] == pytest.approx(100) viewer._pg_window.close() def test_on_max_slider_edit_return_pressed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.max_slider_edit.setText('100') viewer.max_slider_edit.returnPressed.emit() assert viewer.screen_image.white_point == pytest.approx(100) assert viewer.max_slider.value() == pytest.approx(100) assert viewer._pg_window.getLevels()[1] == pytest.approx(100) viewer._pg_window.close() def test_on_min_slider_value_changed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.min_slider.valueChanged.emit(100) assert viewer.screen_image.black_point == pytest.approx(100) assert float(viewer.min_slider_edit.text()) == pytest.approx(100) assert viewer._pg_window.getLevels()[0] == pytest.approx(100) viewer._pg_window.close() def test_on_max_slider_value_changed(self, viewer): viewer.images = np.arange(256).reshape(16, 16) viewer.max_slider.valueChanged.emit(100) assert viewer.screen_image.white_point == pytest.approx(100) assert float(viewer.max_slider_edit.text()) == pytest.approx(100) assert viewer._pg_window.getLevels()[1] == pytest.approx(100) viewer._pg_window.close() def test_popup(self, qtbot, viewer): # Close and another popup call must show the widget viewer._pg_window.close() viewer.popup() assert viewer._pg_window.isVisible() # 2D must work other = ImageViewer() other.images = np.ones((4, 4)) other.popup() qtbot.addWidget(other._pg_window) assert other._pg_window is not None viewer._pg_window.close() other._pg_window.close() tofu-0.12.0/tofu/util.py000066400000000000000000000246061423713721100150640ustar00rootroot00000000000000"""Various utility functions.""" import argparse import glob import logging import math import os from collections import OrderedDict LOG = logging.getLogger(__name__) def range_list(value): """ Split *value* separated by ':' into int triple, filling missing values with 1s. """ def check(region): if region[0] >= region[1]: raise argparse.ArgumentTypeError("{} must be less than {}".format(region[0], region[1])) lst = [int(x) for x in value.split(':')] if len(lst) == 1: frm = lst[0] return (frm, frm + 1, 1) if len(lst) == 2: check(lst) return (lst[0], lst[1], 1) if len(lst) == 3: check(lst) return (lst[0], lst[1], lst[2]) raise argparse.ArgumentTypeError("Cannot parse {}".format(value)) def make_subargs(args, subargs): """Return an argparse.Namespace consisting of arguments from *args* which are listed in the *subargs* list.""" namespace = argparse.Namespace() for subarg in subargs: setattr(namespace, subarg, getattr(args, subarg)) return namespace def set_node_props(node, args): """Set up *node*'s properties to *args* which is a dictionary of values.""" for name in dir(node.props): if not name.startswith('_') and hasattr(args, name): value = getattr(args, name) if value is not None: LOG.debug("Setting {}:{} to {}".format(node.get_plugin_name(), name, value)) node.set_property(name, getattr(args, name)) def get_filenames(path): """ Get all filenams from *path*, which could be a directory or a pattern for matching files in a directory. """ return sorted(glob.glob(os.path.join(path, '*') if os.path.isdir(path) else path)) def setup_read_task(task, path, args): """Set up *task* and take care of handling file types correctly.""" task.props.path = path fnames = get_filenames(path) if fnames and fnames[0].endswith('.raw'): if not args.width or not args.height: raise RuntimeError("Raw files require --width, --height and --bitdepth arguments.") task.props.raw_width = args.width task.props.raw_height = args.height task.props.raw_bitdepth = args.bitdepth def restrict_value(limits, dtype=float): """Convert value to *dtype* and make sure it is within *limits* (included) specified as tuple (min, max). If one of the tuple values is None it is ignored.""" def check(value): result = dtype(value) if limits[0] is not None and result < limits[0]: raise argparse.ArgumentTypeError('Value cannot be less than {}'.format(limits[0])) if limits[1] is not None and result > limits[1]: raise argparse.ArgumentTypeError('Value cannot be greater than {}'.format(limits[1])) return result return check def convert_filesize(value): multiplier = 1 conv = OrderedDict((('k', 2 ** 10), ('m', 2 ** 20), ('g', 2 ** 30), ('t', 2 ** 40))) if not value[-1].isdigit(): if value[-1] not in list(conv.keys()): raise argparse.ArgumentTypeError('--output-bytes-per-file must either be a ' + 'number or end with {} '.format(list(conv.keys())) + 'to indicate kilo, mega, giga or terabytes') multiplier = conv[value[-1]] value = value[:-1] value = int(float(value) * multiplier) if value < 0: raise argparse.ArgumentTypeError('--output-bytes-per-file cannot be less than zero') return value def tupleize(num_items=None, conv=float, dtype=tuple): """Convert comma-separated string values to a *num-items*-tuple of values converted with *conv*. """ def split_values(value): """Convert comma-separated string *value* to a tuple of numbers.""" try: result = dtype([conv(x) for x in value.split(',')]) except: raise argparse.ArgumentTypeError('Expect comma-separated tuple') if num_items and len(result) != num_items: raise argparse.ArgumentTypeError('Expected {} items'.format(num_items)) return result return split_values def next_power_of_two(number): """Compute the next power of two of the *number*.""" return 2 ** int(math.ceil(math.log(number, 2))) def read_image(filename): """Read image from file *filename*.""" if filename.lower().endswith('.tif') or filename.lower().endswith('.tiff'): from tifffile import TiffFile import numpy as np with TiffFile(filename) as tif: return tif.asarray(out='memmap') elif '.edf' in filename.lower(): import fabio edf = fabio.edfimage.edfimage() edf.read(filename) return edf.data else: raise ValueError('Unsupported image format') def get_image_shape(filename): """Determine image shape (numpy order) from file *filename*.""" if filename.lower().endswith('.tif') or filename.lower().endswith('.tiff'): from tifffile import TiffFile with TiffFile(filename) as tif: page = tif.pages[0] shape = (page.imagelength, page.imagewidth) if len(tif.pages) > 1: shape = (len(tif.pages),) + shape else: # fabio doesn't seem to be able to read the shape without reading the data shape = read_image(filename).shape return shape def get_first_filename(path): """Returns the first valid image filename in *path*.""" if not path: raise RuntimeError("Path to sinograms or projections not set.") filenames = get_filenames(path) if not filenames: raise RuntimeError("No files found in `{}'".format(path)) return filenames[0] def determine_shape(args, path=None, store=False): """Determine input shape from *args* which means either width and height are specified in args or try to read the *path* and determine the shape from it. The default path is args.projections, which is the typical place to find the input. If *store* is True, assign the determined values if they aren't already present in *args*. Return a tuple (width, height). """ width = args.width height = args.height if not (width and height): filename = get_first_filename(path or args.projections) try: shape = get_image_shape(filename) # Now set the width and height if not specified width = width or shape[-1] height = height or shape[-2] except: LOG.info("Couldn't determine image dimensions from '{}'".format(filename)) if store: if not args.width: args.width = width if not args.height: args.height = height - args.y return (width, height) def get_filtering_padding(width): """Get the number of horizontal padded pixels in order to avoid convolution artifacts.""" return next_power_of_two(2 * width) - width def setup_padding(pad, width, height, mode, crop=None, pad_width=0, pad_height=0): if not pad_width: # Default is horizontal padding only pad_width = get_filtering_padding(width) pad.props.width = width + pad_width pad.props.height = height + pad_height pad.props.x = pad_width // 2 pad.props.y = pad_height // 2 pad.props.addressing_mode = mode LOG.debug('Padded size: ({}, {})'.format(width + pad_width, height + pad_height)) LOG.debug('Padding mode: {}'.format(mode)) if crop: # crop to original width after filtering crop.props.width = width crop.props.height = height crop.props.x = pad_width // 2 crop.props.y = pad_height // 2 return (pad_width, pad_height) def make_region(n, dtype=int): """Make region in such a way that in case of odd *n* it is centered around 0. Use *dtype* as data type. """ return (-dtype(n / 2), dtype(n / 2 + n % 2), dtype(1)) def get_reconstructed_cube_shape(x_region, y_region, z_region): """Get the shape of the reconstructed cube as (slice width, slice height, num slices).""" import numpy as np z_start, z_stop, z_step = z_region y_start, y_stop, y_step = y_region x_start, x_stop, x_step = x_region num_slices = len(np.arange(z_start, z_stop, z_step)) slice_height = len(np.arange(y_start, y_stop, y_step)) slice_width = len(np.arange(x_start, x_stop, x_step)) return slice_width, slice_height, num_slices def get_reconstruction_regions(params, store=False, dtype=int): """Compute reconstruction regions along all three axes, use *dtype* for as data type for x and y regions, z region is always float. """ width, height = determine_shape(params) if getattr(params, 'transpose_input', False): # In case down the pipeline there is a transpose task tmp = width width = height height = tmp if params.x_region[1] == -1: x_region = make_region(width, dtype=dtype) else: x_region = params.x_region if params.y_region[1] == -1: y_region = make_region(width, dtype=dtype) else: y_region = params.y_region if params.region[1] == -1: region = make_region(height, dtype=float) else: region = params.region LOG.info('X region: {}'.format(x_region)) LOG.info('Y region: {}'.format(y_region)) LOG.info('Parameter region: {}'.format(region)) if store: params.x_region = x_region params.y_region = y_region params.region = region return x_region, y_region, region def get_scarray_value(scarray, index): if len(scarray) == 1: return scarray[0] return scarray[index] class Vector(object): """A vector based on axis-angle representation.""" def __init__(self, x_angle=0, y_angle=0, z_angle=0, position=None): import numpy as np self.position = np.array(position, dtype=np.float) if position is not None else None self.x_angle = x_angle self.y_angle = y_angle self.z_angle = z_angle def __repr__(self): return 'Vector(position={}, angles=({}, {}, {}))'.format(self.position, self.x_angle, self.y_angle, self.z_angle) def __str__(self): return repr(self) tofu-0.12.0/tofu/vis/000077500000000000000000000000001423713721100143265ustar00rootroot00000000000000tofu-0.12.0/tofu/vis/__init__.py000066400000000000000000000000001423713721100164250ustar00rootroot00000000000000tofu-0.12.0/tofu/vis/qt.py000066400000000000000000000132501423713721100153250ustar00rootroot00000000000000import pyqtgraph as pg import pyqtgraph.opengl as gl import logging import numpy as np import tifffile from PyQt4 import QtGui, QtCore LOG = logging.getLogger(__name__) def read_tiff(filename): tiff = tifffile.TiffFile(filename) array = tiff.asarray() return array.T def remove_extrema(data): upper = np.percentile(data, 99) lower = np.percentile(data, 1) data[data > upper] = upper data[data < lower] = lower return data def create_volume(data): gradient = (data - np.roll(data, 1))**2 cmin = gradient.min() div = gradient.max() - cmin gradient = (gradient - cmin) / div * 255 volume = np.empty(data.shape + (4, ), dtype=np.ubyte) volume[..., 0] = data volume[..., 1] = data volume[..., 2] = data volume[..., 3] = gradient return volume class ImageViewer(QtGui.QWidget): """ Present a sequence of files that can be browsed with a slider. To get the currently selected position connect to the *slider* attribute's valueChanged signal. """ def __init__(self, filenames, parent=None): super(ImageViewer, self).__init__(parent) image_view = pg.ImageView() image_view.getView().setAspectLocked(True) self.image_item = image_view.getImageItem() self.slider = QtGui.QSlider(QtCore.Qt.Horizontal) self.slider.valueChanged.connect(self.update_image) self.main_layout = QtGui.QVBoxLayout(self) self.main_layout.addWidget(image_view) self.main_layout.addWidget(self.slider) self.setLayout(self.main_layout) self.load_files(filenames) def load_files(self, filenames): """Load *filenames* for display.""" self.filenames = filenames self.slider.setRange(0, len(self.filenames) - 1) self.slider.setSliderPosition(0) self.update_image() def update_image(self): """Update the currently display image.""" if self.filenames: pos = self.slider.value() image = read_tiff(self.filenames[pos]) self.image_item.setImage(image) class ImageWindow(object): """ Stand-alone window to display image sequences. """ global_app = None def __init__(self, filenames): self.global_app = QtGui.QApplication.instance() or QtGui.QApplication([]) self.viewer = ImageViewer(filenames) self.viewer.show() class OverlapViewer(QtGui.QWidget): """ Presents two images by subtracting the flipped second from the first. To get the current deviation connect to the *slider* attribute's valueChanged signal. """ def __init__(self, parent=None, remove_extrema=False): super(OverlapViewer, self).__init__() image_view = pg.ImageView() image_view.getView().setAspectLocked(True) self.image_item = image_view.getImageItem() self.slider = QtGui.QSlider(QtCore.Qt.Horizontal) self.slider.setRange(0, 0) self.slider.valueChanged.connect(self.update_image) self.main_layout = QtGui.QVBoxLayout() self.main_layout.addWidget(image_view) self.main_layout.addWidget(self.slider) self.setLayout(self.main_layout) self.first, self.second = (None, None) self.remove_extrema = remove_extrema self.subtract = True def set_images(self, first, second): """Set *first* and *second* image.""" self.first, self.second = first.T, np.flipud(second.T) if self.remove_extrema: self.first = remove_extrema(self.first) self.second = remove_extrema(self.second) if self.first.shape != self.second.shape: LOG.warn("Shape {} of {} is different to {} of {}". format(self.first.shape, self.first, self.second.shape, self.second)) width = self.first.shape[0] self.slider.setRange(-width / 2, int(1.5 * width)) self.slider.setSliderPosition(self.first.shape[0] / 2) self.update_image() def set_position(self, position): self.slider.setValue(int(position)) self.update_image() def update_image(self): """Update the current subtraction.""" if self.first is None or self.second is None: LOG.warn("No images set yet") else: pos = self.slider.value() moved = np.roll(self.second, self.second.shape[0] // 2 - pos, axis=0) if self.subtract: self.image_item.setImage(moved - self.first) else: self.image_item.setImage(moved + self.first) class VolumeViewer(QtGui.QWidget): def __init__(self, step=1, density=1, parent=None): super(VolumeViewer, self).__init__(parent) self.volume_view = gl.GLViewWidget() self.main_layout = QtGui.QVBoxLayout() self.main_layout.addWidget(self.volume_view) self.setLayout(self.main_layout) self.step = step self.density = density def load_files(self, filenames): """Load *filenames* for display.""" filenames = filenames[::self.step] num = len(filenames) first = read_tiff(filenames[0])[::self.step, ::self.step] width, height = first.shape data = np.empty((width, height, num), dtype=np.float32) data[:,:,0] = first for i, filename in enumerate(filenames[1:]): data[:, :, i + 1] = read_tiff(filename)[::self.step, ::self.step] volume = create_volume(data) dx, dy, dz, _ = volume.shape volume_item = gl.GLVolumeItem(volume, sliceDensity=self.density) volume_item.translate(-dx / 2, -dy / 2, -dz / 2) volume_item.scale(0.05, 0.05, 0.05, local=False) self.volume_view.addItem(volume_item) tofu-0.12.0/tox.ini000066400000000000000000000000771423713721100140670ustar00rootroot00000000000000[flake8] ignore = E402, E721, E722, W503 max-line-length = 100