pax_global_header 0000666 0000000 0000000 00000000064 14237137211 0014513 g ustar 00root root 0000000 0000000 52 comment=048d50c8725305567469eafb0de3e07a82e65b59
tofu-0.12.0/ 0000775 0000000 0000000 00000000000 14237137211 0012550 5 ustar 00root root 0000000 0000000 tofu-0.12.0/.gitignore 0000664 0000000 0000000 00000000074 14237137211 0014541 0 ustar 00root root 0000000 0000000 *.pyc
build/
dist/
*.egg-info/
install_manifest*.txt
.idea/
tofu-0.12.0/LICENSE 0000664 0000000 0000000 00000016743 14237137211 0013570 0 ustar 00root root 0000000 0000000 GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc.
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.
tofu-0.12.0/MANIFEST.in 0000664 0000000 0000000 00000000047 14237137211 0014307 0 ustar 00root root 0000000 0000000 include pkgconfig.py
include README.md
tofu-0.12.0/README.md 0000664 0000000 0000000 00000006517 14237137211 0014040 0 ustar 00root root 0000000 0000000 ## About
This repository contains Python data processing scripts to be used with the UFO
framework. At the moment they are targeted at high-performance reconstruction of
tomographic data sets.
## Installation
Run
pip install -r requirements-guis.txt # If you want to use flow or ez
python setup.py install
in a prepared virtualenv or as root for system-wide installation. Note that if
you plan to use the graphical user interface you need PyQt4, pyqtgraph and
PyOpenGL. You are strongly advised to install PyQt through your system package
manager; you can install pyqtgraph and PyOpenGL using the pip package manager
though:
pip install pyqtgraph PyOpenGL
## Usage
### Flow
`tofu flow` is a visual flow programming tool. You can create a flow by using any task from [ufo-filters](https://github.com/ufo-kit/ufo-filters) and execute it. It includes visualization of 2D and 3D results, so you can quickly check the output of your flow, which is useful for finding algorithm parameters.

### Reconstruction
To do a tomographic reconstruction you simply call
$ tofu tomo --sinograms $PATH_TO_SINOGRAMS
from the command line. To get correct results, you may need to append
options such as `--axis-pos/-a` and `--angle-step/-a` (which are given in
radians!). Input paths are either directories or glob patterns. Output paths are
either directories or a format that contains one `%i`
[specifier](http://www.pixelbeat.org/programming/gcc/format_specs.html):
$ tofu tomo --axis-pos=123.4 --angle-step=0.000123 \
--sinograms="/foo/bar/*.tif" --output="/output/slices-%05i.tif"
You can get help for all options by running
$ tofu tomo --help
and more verbose output by running with the `-v/--verbose` flag.
You can also load reconstruction parameters from a configuration file called
`reco.conf`. You may create a template with
$ tofu init
Note that options passed via the command line always override configuration
parameters!
Besides scripted reconstructions, one can also run a standalone GUI for both
reconstruction and quick assessment of the reconstructed data via
$ tofu gui

### Performance measurement
If you are running at least ufo-core/filters 0.6, you can evaluate the performance
of the filtered backprojection (without sinogram transposition!), with
$ tofu perf
You can customize parameter scans pretty easily via
$ tofu perf --width 256:8192:256 --height 512
which will reconstruct all combinations of width between 256 and 8192 with a
step of 256 and a fixed height of 512 pixels.
### Estimating the center of rotation
If you do not know the correct center of rotation from your experimental setup,
you can estimate it with:
$ tofu estimate -i $PATH_TO_SINOGRAMS
Currently, a modified algorithm based on the work of [Donath et
al.](http://dx.doi.org/10.1364/JOSAA.23.001048) is used to determine the center.
## Citation
If you use this software for publishing your data, we kindly ask to cite the article below.
Faragó, T., Gasilov, S., Emslie, I., Zuber, M., Helfen, L., Vogelgesang, M. & Baumbach, T. (2022). J. Synchrotron Rad.
29, https://doi.org/10.1107/S160057752200282X
tofu-0.12.0/bin/ 0000775 0000000 0000000 00000000000 14237137211 0013320 5 ustar 00root root 0000000 0000000 tofu-0.12.0/bin/tofu 0000775 0000000 0000000 00000017263 14237137211 0014234 0 ustar 00root root 0000000 0000000 #!/usr/bin/env python3
import os
import sys
import argparse
import logging
import time
import re
import gi
from tofu import config, __version__
# Request version 0.0 of the 'Ufo' GObject-introspection namespace. This has to
# happen before any `from gi.repository import Ufo`; the tofu sub-modules are
# imported lazily inside the command handlers below.
# NOTE(review): `tofu.config` is imported above this call -- presumably it does
# not load Ufo itself; verify if a PyGIWarning about the version shows up.
gi.require_version('Ufo', '0.0')
# Package-wide logger; its level and handlers are configured in main().
LOG = logging.getLogger('tofu')
def init(args):
    """Create a template configuration file at ``args.config``.

    An existing file is never overwritten: if the path already exists a
    RuntimeError is raised (main() logs it and exits with status 1).
    """
    target = args.config
    if os.path.exists(target):
        raise RuntimeError("{0} already exists".format(target))
    config.write(target)
def run_tomo(args):
    """Handler for ``tofu tomo``: run a tomographic reconstruction."""
    # Imported on demand so other sub-commands do not load the reco stack.
    from tofu.reco import tomo
    tomo(args)
def run_lamino(args):
    """Handler for ``tofu lamino``: run a laminographic reconstruction."""
    from tofu.lamino import lamino as reconstruct_lamino
    reconstruct_lamino(args)
def run_genreco(args):
    """Handler for ``tofu reco``: run the general projection-based reconstruction."""
    from tofu.genreco import genreco as general_reconstruction
    general_reconstruction(args)
def run_flat_correct(args):
    """Handler for ``tofu flatcorrect``: run flat-field correction."""
    from tofu.preprocess import run_flat_correct as flat_correct
    flat_correct(args)
def run_preprocessing(args):
    """Handler for ``tofu preprocess``: run the preprocessing pipeline."""
    from tofu.preprocess import run_preprocessing as preprocess_data
    preprocess_data(args)
def run_sinos(args):
    """Handler for ``tofu sinos``: generate sinograms from projections."""
    from tofu.preprocess import run_sinogram_generation
    run_sinogram_generation(args)
def run_ez(args):
    """Handler for ``tofu ez``: launch the ez pipeline-building GUI."""
    from tofu.ez.GUI import ezufo_launcher
    ezufo_launcher.main_qt(args)
def get_ipython_shell(config=None):
    """Return an embeddable IPython shell matching the installed IPython API.

    Three API generations are handled: ``IPython.Shell.IPShellEmbed`` for
    IPython < 0.11, ``IPython.frontend.terminal.embed`` for IPython < 1.0 and
    ``IPython.terminal.embed`` for everything newer.

    :param config: optional IPython configuration forwarded to
        ``InteractiveShellEmbed`` (not used for IPython < 0.11).
    :returns: a shell object; calling it starts the interactive session.
    """
    import IPython
    version = IPython.__version__

    def cmp_versions(v1, v2):
        """Compare two version strings; return -1, 0 or 1 like ``cmp``.

        Only the leading dotted run of integers is compared, so suffixes such
        as ``rc1`` or ``.dev0`` are ignored instead of making ``int()`` raise
        ValueError. Trailing zero components are dropped, so '1.0' == '1'.
        """
        def normalize(v):
            match = re.match(r'\d+(?:\.\d+)*', v)
            numbers = [int(x) for x in match.group(0).split('.')] if match else [0]
            while len(numbers) > 1 and numbers[-1] == 0:
                numbers.pop()
            return numbers

        n1 = normalize(v1)
        n2 = normalize(v2)
        return (n1 > n2) - (n1 < n2)

    if cmp_versions(version, '0.11') < 0:
        from IPython.Shell import IPShellEmbed
        return IPShellEmbed()

    if cmp_versions(version, '1.0') < 0:
        from IPython.frontend.terminal.embed import InteractiveShellEmbed
    else:
        from IPython.terminal.embed import InteractiveShellEmbed
    return InteractiveShellEmbed(config=config, banner1='')
def run_shell(args):
    """Handler for ``tofu interactive``: open an embedded IPython shell.

    NOTE(review): ``reco`` is imported but not referenced below. Embedded
    IPython shells usually expose the calling frame's locals, so ``reco`` and
    ``args`` are presumably meant to be usable inside the shell -- confirm
    before removing the import or renaming these locals.
    """
    from tofu import reco
    shell = get_ipython_shell()
    # Blocks until the user leaves the interactive session.
    shell()
def run_find_large_spots(args):
    """Handler for ``tofu find-large-spots``: detect large spots on images."""
    from tofu.find_large_spots import find_large_spots as detect_spots
    detect_spots(args)
def gui(args):
    """Handler for ``tofu gui``: start the reconstruction GUI.

    An ImportError, whether raised while importing :mod:`tofu.gui` or from
    within its ``main``, is logged as an error rather than propagated.
    """
    try:
        from tofu import gui as gui_module
        gui_module.main(args)
    except ImportError as error:
        LOG.error(str(error))
def run_flow(args):
    """Handler for ``tofu flow``: open the visual flow editor.

    *args* is accepted for a uniform handler signature but is not used.
    """
    from tofu.flow.main import main as launch_flow
    launch_flow()
def estimate(params):
    """Handler for ``tofu estimate``: print the estimated center of rotation.

    With ``params.verbose`` the value is printed with a descriptive prefix;
    otherwise only the bare value is printed, which is convenient in scripts.
    """
    from tofu.reco import estimate_center
    center = estimate_center(params)
    print('>>> Best axis of rotation: {}'.format(center) if params.verbose else center)
def perf(args):
    """Handler for ``tofu perf``: benchmark dry-run reconstructions.

    Every combination of width, height and projection count taken from
    ``range(*args.width_range)``, ``range(*args.height_range)`` and
    ``range(*args.num_projection_range)`` is reconstructed ``args.num_runs``
    times. For each combination, averaged timings and derived throughput
    figures are written to stdout as a single line.

    Note: *args* is modified in place. The input paths are cleared,
    ``dry_run`` is forced on, and ``width``/``height``/``number`` are
    overwritten for every combination.
    """
    from tofu import reco

    def measure(args):
        """Time ``args.num_runs`` reconstructions of one size and print a report."""
        # Read the problem size from *args* rather than from perf()'s loop
        # variables, so the report always matches what was reconstructed.
        width = args.width
        height = args.height
        num_projections = args.number
        exec_times = []
        total_times = []
        for _ in range(args.num_runs):
            start = time.time()
            # NOTE(review): reco.tomo() is assumed to return the execution
            # time of the run in seconds (it is averaged as "exec" below).
            exec_times.append(reco.tomo(args))
            total_times.append(time.time() - start)
        exec_time = sum(exec_times) / len(exec_times)
        total_time = sum(total_times) / len(total_times)
        # Wall-clock time spent outside the reported execution, in percent.
        overhead = (total_time / exec_time - 1.0) * 100
        # Bandwidths in MiB/s, counting 4 bytes per pixel.
        input_bandwidth = width * height * num_projections * 4 / exec_time / 1024. / 1024.
        output_bandwidth = width * width * height * 4 / exec_time / 1024. / 1024.
        slice_bandwidth = height / exec_time
        # Four bytes of our output bandwidth constitute one slice pixel, for each
        # pixel we have to do roughly n * 6 floating point ops (2 mad, 1 add, 1
        # interpolation)
        flops = output_bandwidth / 4 * 6 * num_projections / 1024
        msg = ("width={:<6d} height={:<6d} n_proj={:<6d} "
               "exec={:.4f}s total={:.4f}s overhead={:.2f}% "
               "bandwidth_i={:.2f}MB/s bandwidth_o={:.2f}MB/s slices={:.2f}/s "
               "flops={:.2f}GFLOPs\n")
        sys.stdout.write(msg.format(width, height, num_projections,
                                    exec_time, total_time, overhead,
                                    input_bandwidth, output_bandwidth, slice_bandwidth, flops))
        sys.stdout.flush()

    # Benchmark without real input data.
    args.projections = None
    args.sinograms = None
    args.dry_run = True

    for width in range(*args.width_range):
        for height in range(*args.height_range):
            for num_projections in range(*args.num_projection_range):
                args.width = width
                args.height = height
                args.number = num_projections
                measure(args)
def main():
    """Parse the command line, configure logging and run the chosen sub-command.

    Each sub-command stores its handler via ``set_defaults(_func=...)``, and
    that handler is called with the parsed arguments. A RuntimeError raised by
    the handler is logged, and the process exits with status 1.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', **config.SECTIONS['general']['config'])
    parser.add_argument('--version', action='version',
                        version='%(prog)s {}'.format(__version__))

    # Configuration sections whose options each sub-command accepts.
    sino_params = ('flat-correction', 'sinos')
    tomo_params = config.TOMO_PARAMS
    lamino_params = config.LAMINO_PARAMS
    gui_params = tomo_params + ('gui', )

    # (command name, handler, config sections, help text)
    cmd_parsers = [
        ('init', init, (), "Create configuration file"),
        ('preprocess', run_preprocessing, config.PREPROC_PARAMS, "Run preprocessing"),
        ('flatcorrect', run_flat_correct, ('flat-correction',), "Run flat field correction"),
        ('sinos', run_sinos, sino_params, "Generate sinograms from projections"),
        ('tomo', run_tomo, tomo_params, "Run tomographic reconstruction"),
        ('lamino', run_lamino, lamino_params, "Run laminographic reconstruction"),
        ('reco', run_genreco, config.GEN_RECO_PARAMS, "Run general projection-based "
                                                      "reconstruction for tomographic/"
                                                      "laminographic cone/parallel beam"),
        ('gui', gui, gui_params, "GUI for tomographic reconstruction"),
        ('flow', run_flow, (), "Visual flow creation"),
        ('ez', run_ez, (), "GUI for making ufo-kit data processing pipelines"),
        ('estimate', estimate, tomo_params + ('estimate',), "Estimate center of rotation"),
        ('perf', perf, tomo_params + ('perf',), "Check reconstruction performance"),
        ('interactive', run_shell, tomo_params, "Run interactive mode"),
        ('find-large-spots', run_find_large_spots, ('find-large-spots',), "Find large spots on images"),
    ]

    # add_subparsers() accepts ``required=`` only since Python 3.7. Compare
    # version tuples: a string comparison on sys.version is lexicographic and
    # would treat '3.10' and later as older than '3.7'.
    if sys.version_info < (3, 7):
        subparsers = parser.add_subparsers(title="Commands", dest='commands')
    else:
        subparsers = parser.add_subparsers(title="Commands", dest='commands', required=True)

    for cmd, func, sections, text in cmd_parsers:
        cmd_params = config.Params(sections=sections)
        cmd_parser = subparsers.add_parser(cmd, help=text, formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        cmd_parser = cmd_params.add_arguments(cmd_parser)
        cmd_parser.set_defaults(_func=func)

    args = config.parse_known_args(parser, subparser=True)

    log_level = logging.DEBUG if args.verbose else logging.INFO
    LOG.setLevel(log_level)
    stream_handler = logging.StreamHandler(sys.stdout)
    stream_handler.setFormatter(logging.Formatter('%(levelname)s: %(message)s'))
    LOG.addHandler(stream_handler)

    if args.log:
        file_handler = logging.FileHandler(args.log)
        file_handler.setFormatter(logging.Formatter('[%(asctime)s] %(name)s:%(levelname)s: %(message)s'))
        LOG.addHandler(file_handler)

    try:
        config.log_values(args)
        args._func(args)
    except RuntimeError as e:
        LOG.error(str(e))
        sys.exit(1)
# Run the command-line interface when this file is executed as a script.
if __name__ == '__main__':
    main()
# vim: ft=python
tofu-0.12.0/docs/ 0000775 0000000 0000000 00000000000 14237137211 0013500 5 ustar 00root root 0000000 0000000 tofu-0.12.0/docs/Makefile 0000664 0000000 0000000 00000001110 14237137211 0015131 0 ustar 00root root 0000000 0000000 # Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
# Extra options passed to every sphinx-build invocation (empty by default).
SPHINXOPTS =
# The sphinx-build executable to run.
SPHINXBUILD = sphinx-build
# Directory with conf.py and the .rst sources, relative to this Makefile.
SOURCEDIR = source
# Directory that receives the generated documentation.
BUILDDIR = build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) tofu-0.12.0/docs/source/ 0000775 0000000 0000000 00000000000 14237137211 0015000 5 ustar 00root root 0000000 0000000 tofu-0.12.0/docs/source/api/ 0000775 0000000 0000000 00000000000 14237137211 0015551 5 ustar 00root root 0000000 0000000 tofu-0.12.0/docs/source/api/genreco.rst 0000664 0000000 0000000 00000000121 14237137211 0017717 0 ustar 00root root 0000000 0000000 3D Reconstruction
=================
.. automodule:: tofu.genreco
:members:
tofu-0.12.0/docs/source/api/preprocessing.rst 0000664 0000000 0000000 00000000115 14237137211 0021163 0 ustar 00root root 0000000 0000000 Pre-processing
==============
.. automodule:: tofu.preprocess
:members:
tofu-0.12.0/docs/source/api/util.rst 0000664 0000000 0000000 00000000075 14237137211 0017262 0 ustar 00root root 0000000 0000000 Utilities
=========
.. automodule:: tofu.util
:members:
tofu-0.12.0/docs/source/conf.py 0000664 0000000 0000000 00000013315 14237137211 0016302 0 ustar 00root root 0000000 0000000 # -*- coding: utf-8 -*-
# Sphinx configuration for the Tofu documentation.
# Full list of options: http://www.sphinx-doc.org/en/master/config
import os
import sys

# -- Path setup --------------------------------------------------------------
# Put the repository root first on sys.path so autodoc imports the local
# ``tofu`` package. The path is relative, so this assumes the file is executed
# with docs/source as the working directory.
sys.path.insert(0, os.path.abspath(os.path.join('..', '..')))

# -- Project information -----------------------------------------------------
project = 'Tofu'
author = 'Tomas Farago'
copyright = '2020, ' + author
# Short X.Y version and full release string; both intentionally left empty.
version = release = ''
# -- General configuration ---------------------------------------------------
# Sphinx extensions in use (all of them ship with Sphinx).
extensions = [
    'sphinx.ext.autodoc',
    'sphinx.ext.doctest',
    'sphinx.ext.intersphinx',
    'sphinx.ext.todo',
    'sphinx.ext.coverage',
    'sphinx.ext.imgmath',
    'sphinx.ext.ifconfig',
    'sphinx.ext.viewcode',
    'sphinx.ext.githubpages',
]
# Mock ``gi`` (PyGObject) so autodoc can import tofu modules without it.
autodoc_mock_imports = ['gi']
# Paths containing templates, relative to this directory.
templates_path = ['_templates']
# Suffix of source files.
source_suffix = '.rst'
# The master toctree document.
master_doc = 'index'
# Language of autogenerated content. Sphinx >= 5 rejects ``None`` here (it
# warns and falls back to English), so set the language code explicitly.
language = 'en'
# Patterns, relative to the source directory, of files to ignore.
exclude_patterns = []
# Pygments style for syntax highlighting (None selects the theme default).
pygments_style = None
# -- Options for HTML output -------------------------------------------------
# Bundled Alabaster theme, with theme options left at their defaults.
html_theme = 'alabaster'
# Custom static files, relative to this directory. They are copied after the
# theme's own static files, so a same-named file overrides the builtin one.
html_static_path = ['_static']

# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for the HTML help builder.
htmlhelp_basename = 'Tofudoc'

# -- Options for LaTeX output ------------------------------------------------
# Sphinx defaults for everything. Keys that can be set here include
# 'papersize', 'pointsize', 'preamble' and 'figure_align'.
latex_elements = {}

# (source start file, target name, title, author, document class)
latex_documents = [
    (master_doc, 'Tofu.tex', 'Tofu Documentation', 'Tomas Farago', 'manual'),
]

# -- Options for manual page output ------------------------------------------
# (source start file, name, description, authors, manual section)
man_pages = [
    (master_doc, 'tofu', 'Tofu Documentation', [author], 1),
]

# -- Options for Texinfo output ----------------------------------------------
# (source start file, target name, title, author, dir menu entry,
#  description, category). The description replaces the quickstart
# placeholder with the project summary from the README.
texinfo_documents = [
    (master_doc, 'Tofu', 'Tofu Documentation', author, 'Tofu',
     'Python data processing scripts for the UFO framework.',
     'Miscellaneous'),
]

# -- Options for Epub output -------------------------------------------------
epub_title = project
# Files that should not be packed into the EPUB.
epub_exclude_files = ['search.html']
# -- Extension configuration -------------------------------------------------
# intersphinx: resolve references to the Python standard library docs. This
# uses the named ``{name: (target_url, inventory)}`` form; the legacy
# ``{url: None}`` form is deprecated and rejected by recent Sphinx releases.
intersphinx_mapping = {'python': ('https://docs.python.org/3', None)}

# todo: render ``todo`` and ``todolist`` directives in the output.
todo_include_todos = True
tofu-0.12.0/docs/source/index.rst 0000664 0000000 0000000 00000001207 14237137211 0016641 0 ustar 00root root 0000000 0000000 .. Tofu documentation master file, created by
sphinx-quickstart on Fri Aug 14 17:29:07 2020.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to Tofu's documentation!
================================
.. toctree::
:maxdepth: 2
:caption: Contents:
Application Programming Interface
=================================
.. toctree::
:maxdepth: 2
api/preprocessing
api/genreco
api/util
Usage
=====
.. toctree::
:maxdepth: 2
usage/genreco
usage/flow
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
tofu-0.12.0/docs/source/usage/ 0000775 0000000 0000000 00000000000 14237137211 0016104 5 ustar 00root root 0000000 0000000 tofu-0.12.0/docs/source/usage/flow.rst 0000664 0000000 0000000 00000027023 14237137211 0017611 0 ustar 00root root 0000000 0000000 Flow: Visual Graph Creation
===========================
You can use command ``tofu flow`` to start a graphical user interface in which UFO
tasks are represented as nodes which you can connect together. Once you have
created your flow you can execute it.
Nodes
-----
An operation on data is represented by a node in a flow. A node has inputs and
outputs, which have data types. An input or output of a node is
represented by a `port`, which is a circle on the left of the node in case of an
input and on the right in case of output. Every port has a data type which is
represented by color. There are two data types:
- `UFO`: you can connect all UFO nodes together
- `Array`: a numpy array which comes out from UFO's ``memory_out`` node and may be
used to visualize the processing result by ``image_viewer``
A node may have properties (almost all UFO nodes do, e.g. ``path`` in the
``read`` node) which are listed and can be set inside the node. If you hover the
mouse over a property field, a tooltip will be shown describing that property.
When you right click on a node which holds properties, a context menu pops up
and lets you choose which properties you want to be visible and which not. Some
nodes, like ``general_backproject`` have many properties, many of which may be
considered `expert` options which are not needed most of the time. By hiding
these properties, you can avoid clutter. There is a pre-defined set of
properties, which are shown by default. When you create a node in the scene,
this setting is applied and you can check which properties are hidden by default
by clicking on a node right after its creation. In case a node doesn't have
properties, right click either doesn't take effect or pops up a context menu
relevant for that node. E.g. ``image_viewer``'s context menu allows you to
configure the viewer's behavior.
.. double click
A node might implement an action on a double click, e.g. ``read`` node opens a
dialog allowing you to choose the data ``path``, ``image_viewer`` pops up an
external image window which can be enlarged and put on another display for
convenience. Current nodes which implement double clicks are:
- ``Composite``: opens a new window with a scene displaying internal composite
nodes
- ``read``: opens a dialog which allows you to choose the input ``path``
- ``write``: opens a dialog which allows you to choose the output ``filename``
- ``image_viewer``: opens the image in a new window
.. auto fill
``read`` node currently supports an `auto fill` option, which may be invoked via
the main menu bar. The node sets its ``number`` property to the number of
detected images found in the specified ``path``.
UFO Nodes
~~~~~~~~~
A UFO node represents a `UFO task` and holds the properties of this
`UFO task`. Please check the `UFO Filters Reference
<https://ufo-filters.readthedocs.io>`_ for the complete list of UFO tasks and
their properties. When you create a UFO node, its properties are the default
properties of the encapsulated UFO task.
Composite Nodes
~~~~~~~~~~~~~~~
In order to reduce clutter, you can combine several nodes in a composite node
(main menu's `Nodes->Create Composite`) and you can also nest composites, i.e.
have a composite node and create another composite node with the first one
inside. Internal nodes are listed as groups in the composite node in the scene
and similarly to property nodes, you can show and hide different internal nodes
from the listing. The input and output ports of a composite node are the ports
of its internal nodes which are not connected at the time of composite node's
creation.
A double click on a composite node opens its internal nodes in a separate
window, where you can edit their properties but you can't add new nodes or
change connections. You can open this window also by pressing `Nodes->Edit
Composite` in the main menu.
In order to store a composite node for later usage, you can export it into a
file via the main menu's `Nodes->Export Composite`. You can import composite
node definitions by `Nodes->Import Composites`, which are then available in the
flow scene's context menu in the `Composite` category.
There are several pre-defined composite nodes available via the scene's context
menu (category `Composite`), they are:
- ``CFlatFieldCorrect`` encapsulates readers and averagers and the
``flat_field_correct`` node itself
- ``CPhaseRetrieve`` encapsulates padding, Fourier transformation and the phase
retrieval itself
General Backproject
~~~~~~~~~~~~~~~~~~~
This is a versatile back projection node which can reconstruct tomographic,
laminographic, parallel and cone beam data. It has one parameter which is not
part of the UFO task, ``slice-memory-coeff``. This parameter sets the fraction
(0 - 1) of a graphics card's memory which will be used to store the reconstructed
volume. If you are working with graphics cards which have other processes running
on them and these processes use a lot of memory, then you might need to reduce
this parameter.
Phase Retrieval
~~~~~~~~~~~~~~~
The ``retrieve_phase`` node may have a varying number of inputs in order to support
multi-distance phase retrieval. You specify the number of inputs in a dialog
when you create the node. If you specify more than one input, the retrieval
method will be the multidistance contrast transfer function and the ``method``
field will be fixed to `ctf_multidistance`. In this case, fields ``distance-x``
and ``distance-y`` will be disabled. If you specify one input, you may choose
different methods via the ``method`` field. In this case, you can either specify
one value in the ``distance`` field, or specify separate distances for `x` and
`y` directions via ``distance-x`` and ``distance-y`` fields (they take
precedence over ``distance`` field in case they are both non-zero).
Image Viewer
~~~~~~~~~~~~
``image_viewer`` lets you display the results of your flow. It is composed of
the image itself and three text boxes with sliders, which allow you to specify
the image index shown, the black point and white point. In case only one image
is input, the first slider is hidden. Right click on the node opens a context
menu which allows you to reset the black and white points (`Reset`), set them
automatically (`Auto Levels`) and specify whether they should be automatically
adjusted when new images are on input or left unchanged (`Auto Levels on New
Image`). Double click opens the image in a new window by using the PyQtGraph_
library. In case a separate window is open, image index, black and white point
settings can be set either in the flow node or in the window and they are
reflected in both the node and the window.
Flows
-----
On right click in the flow scene a context menu will pop up and you will be able
to add nodes. Then you can connect them by dragging a node's output port into
another node's input port if those ports have the same data type, which are
distinguished by port colors. By connecting node ports you create your flow
which you may later execute. Every node in the scene must have a unique
`caption`, so when you create a ``read`` node, the caption will be ``Read``,
when you create another ``read`` node, the caption will be ``Read 2`` and so on.
This is important for setting property links explained below.
The roots of the flow in the scene must be UFO nodes and leaves may have `UFO` or
`Array` type. It is not possible to go from `UFO` to `Array` and back to `UFO`,
i.e. the `UFO` portion of the flow in the scene must be one contiguous component
of the flow. There may be only one flow in the scene and it must be completely
connected (there can't be disconnected ports, e.g. ``flat_field_correct``'s
``darks`` port).
You can delete the current flow by pressing `Flow->New`, you can save a flow
into a flow file (.flow) by `Flow->Save` and open such files by `Flow->Open`.
Property Links
--------------
A property of a node might depend on another node's property, e.g. the number of
dimensions of an ``ifft`` node depends on the number of dimensions of the
preceding ``fft`` task. In order to reduce the number of properties you need
to set, you can `link` properties together, i.e. when you set one node's
property, all the linked nodes' properties will be updated (e.g. when you change
the number of dimensions of an ``ifft`` node, the number of dimensions of the
linked ``fft`` node will be updated as well).
You can create property links in the `Property Links` window (open via main menu
bar's `View` field). At the top of the window, there is a tree view of the nodes
in the scene. Its items are the nodes in the flow scene, and in case there are
composites, they are listed recursively. The last level of the view are the
properties of the nodes in the flow. You can drag these properties into the list
in the second half of the window to start creating links. If you drag a property
to a new row, or if no row exists yet, a new row is automatically added. If you drag
a property into an existing row (over an existing cell), it is appended to this
row and a link is created. Links are allowed only for properties with compatible
data types, e.g. you cannot link ``read``'s ``path`` (a string) to ``fft``'s
``dimensions`` (a number). Also keep in mind that in nodes which are able to
process batches, the fields responsible for receiving different batches (e.g.
``number`` of the ``memory_out`` node) have the string data type (so that you
can type `{region}` inside).
Execution
---------
Execution of the flow starts with executing the UFO part of the flow, and if
there is a ``memory_out`` and subsequent nodes, they get the result of the
UFO processing as the batches are finished (or just one batch if no
batch-capable nodes are in the flow). You start it by invoking main menu bar's
`Flow->run` action. You can abort the execution by invoking `Flow->abort`.
Batch Processing
~~~~~~~~~~~~~~~~
Some nodes require a lot of GPU memory and they can't process all the input data
at once (e.g. ``general_backproject``). Based on your system, they can split the
work on their own and tell the execution mechanism to run multiple batches. If
your system has multiple GPUs, ``tofu flow`` may create several batches and each
of these batches may be executed on one or more cards in your system.
Currently, only *one* batch processing task is allowed in the flow and only
``general_backproject`` supports batch processing.
In case your flow contains a node which is able to produce batches, then your
consumer nodes must be able to process batches and they must be notified of the
fact that they will get more batches on input. Currently, ``write`` and
``memory_out`` support batches and this is how you set them up for it:
- ``write``: ``filename`` must contain `{region}` somewhere in it, e.g.
`slices-{region}.tif`
- ``memory_out``: ``number`` field must be set to `{region}`
The `{region}` template is then replaced by the current batch identifier provided by the
producer node which is capable of batch processing, e.g. `slices-0.tif`,
`slices-100.tif` and so on.
If there is no node capable of producing batches, this is how you set them up
for normal, non-batch processing:
- ``write``: ``filename`` field set to normal file name, e.g. `slices.tif`
- ``memory_out``: ``number`` field set to the number of input images
Python Console
--------------
Main menu's `View->Python Console` opens up a Python interpreter console with
attribute ``scene`` set to the flow scene, which allows you to interact with the
nodes programatically, see `qtpynodeeditor docs
`_
for more details on the flow scene functionality.
.. _PyQtGraph: http://www.pyqtgraph.org/
.. _qtpynodeeditor:
tofu-0.12.0/docs/source/usage/genreco.rst 0000664 0000000 0000000 00000003346 14237137211 0020266 0 ustar 00root root 0000000 0000000 General 3D Reconstruction
=========================
You can use the command ``tofu reco`` to reconstruct parallel/cone beam
tomography/laminography data. The algorithm is filtered back projection for
parallel beam data and `Feldkamp `_
approach for cone beam data. It always reconstructs 2D slices in the plane
parallel to the beam direction. The third dimension may be the vertical slice
position (the default) but can also be one of the geometrical parameters in order to find
their best values for the final reconstruction (see ``tofu reco --help`` and
check the ``--z-parameter`` entry for possible values). Angular input values are
in degrees.
To reconstruct slices from -100 to 100 with the step size 0.5 around the center which
is defined as 1008.5 from 1500 projections acquired over 180 degrees stored in
``projs.tif``, with rotation axis in pixel 951 one would do::
tofu reco --projections projs.tif --number 1500 --center-position-x 951
--overall-angle 180 --center-position-z 1008.5 --region=-100,100,0.5
--output slices.tif
To scan the roll angle from -2 to 2 degrees with a step of 0.1 degree, one can use
the following command::
tofu reco --projections projs.tif --number 1500 --overall-angle 180
--center-position-x 951 --center-position-z 1008.5 --z-parameter
detector-angle-y --region=-2,2,0.1 --output detector-angle-y-scan.tif
--disable-projection-crop
To scan the rotation axis region from pixel 940 to pixel 960 with step 0.5
pixels (the ``center-position-x`` parameter), one can use::
tofu reco --projections projs.tif --number 1500 --overall-angle 180
--center-position-z 1008.5 --z-parameter center-position-x
--region=940,960,0.5 --output center-position-x-scan.tif
tofu-0.12.0/requirements-flow-tests.txt 0000664 0000000 0000000 00000000021 14237137211 0020132 0 ustar 00root root 0000000 0000000 pytest
pytest-qt
tofu-0.12.0/requirements-flow.txt 0000664 0000000 0000000 00000000106 14237137211 0016776 0 ustar 00root root 0000000 0000000 PyGObject
imageio
numpy
networkx
PyQt5
pyqtconsole
xdg
qtpynodeeditor
tofu-0.12.0/requirements-guis.txt 0000664 0000000 0000000 00000000106 14237137211 0016776 0 ustar 00root root 0000000 0000000 PyGObject
imageio
numpy
networkx
PyQt5
pyqtconsole
xdg
qtpynodeeditor
tofu-0.12.0/setup.py 0000664 0000000 0000000 00000001671 14237137211 0014267 0 ustar 00root root 0000000 0000000 from setuptools import setup, find_packages
from tofu import __version__
# Read the long description inside a context manager so the file handle is
# closed deterministically (a bare open().read() leaks it), and decode it as
# UTF-8 explicitly instead of relying on the platform's locale encoding.
with open('README.md', encoding='utf-8') as readme:
    long_description = readme.read()

setup(
    name='ufo-tofu',
    python_requires='>=3',
    version=__version__,
    author='Matthias Vogelgesang',
    author_email='matthias.vogelgesang@kit.edu',
    url='http://github.com/ufo-kit/tofu',
    license='LGPL',
    packages=find_packages(),
    package_data={'tofu': ['gui.ui', 'ez/GUI/default_settings.yaml'],
                  'tofu.flow': ['composites/*.cm', 'config.json']},
    scripts=['bin/tofu'],
    exclude_package_data={'': ['README.rst']},
    install_requires=[
        'PyGObject',
        'imageio',
        'numpy',
        'networkx',
        'PyQt5',
        'pyqtconsole',
        'xdg',
        'qtpynodeeditor'
    ],
    # Adjacent literals inside the call parentheses are joined implicitly; no
    # backslash continuation is needed.
    description="A fast, versatile and user-friendly image "
                "processing toolkit for computed tomography",
    long_description=long_description,
    long_description_content_type='text/markdown',
)
tofu-0.12.0/tofu/ 0000775 0000000 0000000 00000000000 14237137211 0013525 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/__init__.py 0000664 0000000 0000000 00000000027 14237137211 0015635 0 ustar 00root root 0000000 0000000 __version__ = '0.12.0'
tofu-0.12.0/tofu/config.py 0000664 0000000 0000000 00000070531 14237137211 0015352 0 ustar 00root root 0000000 0000000 import argparse
import sys
import logging
import configparser as configparser
from collections import OrderedDict
from tofu.util import convert_filesize, restrict_value, tupleize, range_list
# Module-level logger.
LOG = logging.getLogger(__name__)
# Default configuration file name; it is the default value of the --config
# option in the 'general' section.
NAME = "reco.conf"
# Mapping of section name -> {option name -> keyword arguments for
# argparse's add_argument}. Params.add_parser_args passes each option dict as
# **opts, and config_to_list walks the sections in insertion order.
SECTIONS = OrderedDict()
# General options (configuration file, verbosity, output format). Params
# appends this section, together with 'reading', to every section tuple.
SECTIONS['general'] = {
    'config': {
        'default': NAME,
        'type': str,
        'help': "File name of configuration",
        'metavar': 'FILE'},
    'verbose': {
        'default': False,
        'help': 'Verbose output',
        'action': 'store_true'},
    'output': {
        'default': 'result-%05i.tif',
        'type': str,
        'help': "Path to location or format-specified file path "
                "for storing reconstructed slices",
        'metavar': 'PATH'},
    'output-bitdepth': {
        'default': 32,
        'type': restrict_value((0, None), dtype=int),
        'help': "Bit depth of output, either 8, 16 or 32",
        'metavar': 'BITDEPTH'},
    'output-minimum': {
        'default': None,
        'type': float,
        'help': "Minimum value that maps to zero",
        'metavar': 'MIN'},
    'output-maximum': {
        'default': None,
        'type': float,
        'help': "Maximum input value that maps to largest output value",
        'metavar': 'MAX'},
    'output-bytes-per-file': {
        'default': '128g',
        'type': convert_filesize,
        # Implicit concatenation instead of a backslash-newline inside the
        # literal, which would embed the next line's indentation in the text.
        'help': "Maximum bytes per file (0=single-image output, otherwise "
                "multi-image output), 'k', 'm', 'g', 't' suffixes can be used",
        'metavar': 'BYTESPERFILE'},
    'output-append': {
        'default': False,
        'action': 'store_true',
        'help': 'Append images instead of overwriting existing files'},
    'log': {
        'default': None,
        'type': str,
        'help': "File name of optional log",
        'metavar': 'FILE'},
    'width': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Input width"}}
# Options selecting which input files and rows are read. Params appends this
# section, together with 'general', to every section tuple.
SECTIONS['reading'] = {
    'y': {
        'type': restrict_value((0, None), dtype=int),
        'default': 0,
        'help': 'Vertical coordinate from where to start reading the input image'},
    'height': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Number of rows which will be read"},
    'bitdepth': {
        'default': 32,
        'type': restrict_value((0, None), dtype=int),
        'help': "Bit depth of raw files"},
    'y-step': {
        'type': restrict_value((0, None), dtype=int),
        'default': 1,
        'help': "Read every \"step\" row from the input"},
    'start': {
        'type': restrict_value((0, None), dtype=int),
        'default': 0,
        'help': 'Offset to the first read file'},
    'number': {
        'type': restrict_value((0, None), dtype=int),
        'default': None,
        'help': 'Number of files to read'},
    'step': {
        'type': restrict_value((0, None), dtype=int),
        'default': 1,
        'help': 'Read every \"step\" file'},
    'resize': {
        'type': restrict_value((0, None), dtype=int),
        'default': None,
        'help': 'Bin pixels before processing'},
    # 'retries' and 'retry-timeout' control waiting for input files that are
    # not there yet (per their help texts); 0 means no waiting.
    'retries': {
        'type': restrict_value((0, None), dtype=int),
        'default': 0,
        'metavar': 'NUMBER',
        'help': 'How many times to wait for new files'},
    'retry-timeout': {
        'type': restrict_value((0, None), dtype=int),
        'default': 0,
        'metavar': 'TIME',
        'help': 'How long to wait for new files per trial'}}
# Flat-field correction inputs (projections, darks, flats) and their scaling.
# Used by TOMO_PARAMS and PREPROC_PARAMS.
SECTIONS['flat-correction'] = {
    'projections': {
        'default': None,
        'type': str,
        'help': "Location with projections",
        'metavar': 'PATH'},
    'darks': {
        'default': None,
        'type': str,
        'help': "Location with darks",
        'metavar': 'PATH'},
    'dark-scale': {
        'default': 1,
        'type': float,
        'help': "Scaling dark"},
    # NOTE(review): the help text is ambiguous about which inputs each mode
    # applies to; confirm against the consumer of 'reduction-mode'.
    'reduction-mode': {
        'default': "Average",
        'type': str,
        'help': "Flat-field correction options: Average (darks) or median (flats)"},
    'fix-nan-and-inf': {
        'default': False,
        'help': "Fix nan and inf",
        'action': 'store_true'},
    'flats': {
        'default': None,
        'type': str,
        'help': "Location with flats",
        'metavar': 'PATH'},
    # Second set of flats, used for interpolating flat-field correction.
    'flats2': {
        'default': None,
        'type': str,
        'help': "Location with flats 2 for interpolation correction",
        'metavar': 'PATH'},
    'flat-scale': {
        'default': 1,
        'type': float,
        'help': "Scaling flat"},
    'absorptivity': {
        'default': False,
        'action': 'store_true',
        'help': 'Do absorption correction'}}
# Phase retrieval options (part of PREPROC_PARAMS).
SECTIONS['retrieve-phase'] = {
    'retrieval-method': {
        'choices': ['tie', 'ctf', 'qp', 'qp2'],
        'default': 'tie',
        'help': "Phase retrieval method"},
    'energy': {
        'default': None,
        'type': float,
        'help': "X-ray energy [keV]"},
    'propagation-distance': {
        'default': None,
        'type': tupleize(),
        # The first literal previously lacked a trailing space, so the help
        # rendered as "x and ydirection".
        'help': ("Sample <-> detector distance (if one value, then use the same for x and y "
                 "direction, otherwise first specifies x and second y direction) [m]")},
    'pixel-size': {
        'default': 1e-6,
        'type': float,
        'help': "Pixel size [m]"},
    'regularization-rate': {
        'default': 2,
        'type': float,
        'help': "Regularization rate (typical values between [2, 3])"},
    'delta': {
        'default': None,
        'type': float,
        'help': "Real part of the complex refractive index of the material. "
                "If specified, phase retrieval returns projected thickness, "
                "if not, it returns phase"},
    'retrieval-padded-width': {
        'default': 0,
        'type': restrict_value((0, None), dtype=int),
        'help': "Padded width used for phase retrieval"},
    'retrieval-padded-height': {
        'default': 0,
        'type': restrict_value((0, None), dtype=int),
        'help': "Padded height used for phase retrieval"},
    'retrieval-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'clamp_to_edge',
        'help': "Padded values assignment"},
    'thresholding-rate': {
        'default': 0.01,
        'type': float,
        'help': "Thresholding rate (typical values between [0.01, 0.1])"},
    'frequency-cutoff': {
        'default': 1e30,
        'type': float,
        'help': "Phase retrieval frequency cutoff [rad]"}}
# Sinogram generation options.
SECTIONS['sinos'] = {
    'pass-size': {
        'type': restrict_value((0, None), dtype=int),
        'default': 0,
        'help': 'Number of sinograms to process per pass'}}
# Options shared by the tomographic reconstruction pipeline (TOMO_PARAMS).
# Several option names here also appear in 'preprocess'; the two sections are
# not combined in any of the parameter tuples below.
SECTIONS['reconstruction'] = {
    'sinograms': {
        'default': None,
        'type': str,
        'help': "Location with sinograms",
        'metavar': 'PATH'},
    'angle': {
        'default': None,
        'type': float,
        'help': "Angle step between projections in radians"},
    'enable-tracing': {
        'default': False,
        'help': "Enable tracing and store result in .PID.json",
        'action': 'store_true'},
    # nargs='+' options are expanded from comma-separated values by
    # config_to_list when read from a config file.
    'remotes': {
        'default': None,
        'type': str,
        'help': "Addresses to remote ufo-nodes",
        'nargs': '+'},
    'projection-filter': {
        'default': 'ramp-fromreal',
        'type': str,
        'help': "Projection filter",
        'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']},
    'projection-filter-cutoff': {
        'default': 0.5,
        'type': float,
        'help': "Relative cutoff frequency"},
    'projection-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'clamp_to_edge',
        'help': "Padded values assignment"}}
# Parallel-beam tomographic reconstruction options (TOMO_PARAMS).
SECTIONS['tomographic-reconstruction'] = {
    'axis': {
        'default': None,
        'type': float,
        'help': "Axis position"},
    'dry-run': {
        'default': False,
        'help': "Reconstruct without writing data",
        'action': 'store_true'},
    'offset': {
        'default': 0.0,
        'type': float,
        'help': "Angle offset of first projection in radians"},
    'method': {
        'default': 'fbp',
        'type': str,
        'help': "Reconstruction method",
        'choices': ['fbp', 'dfi', 'sart', 'sirt', 'sbtv', 'asdpocs']}}
# Laminographic reconstruction options (LAMINO_PARAMS).
SECTIONS['laminographic-reconstruction'] = {
    'angle': {
        'default': None,
        'type': float,
        'help': "Angle step between projections in radians"},
    'dry-run': {
        'default': False,
        'help': "Reconstruct without writing data",
        'action': 'store_true'},
    'axis': {
        'default': None,
        'required': True,
        'type': tupleize(num_items=2),
        'help': "Axis position"},
    'x-region': {
        'default': "0,-1,1",
        'type': tupleize(num_items=3, conv=int),
        'help': "X region as from,to,step"},
    'y-region': {
        'default': "0,-1,1",
        'type': tupleize(num_items=3, conv=int),
        'help': "Y region as from,to,step"},
    'z': {
        'default': 0,
        'type': int,
        'help': "Z coordinate of the reconstructed slice"},
    'z-parameter': {
        'default': 'z',
        'type': str,
        'choices': ['z', 'x-center', 'lamino-angle', 'roll-angle'],
        'help': "Parameter to vary along the reconstructed z-axis"},
    'region': {
        'default': "0,-1,1",
        'type': tupleize(num_items=3),
        'help': "Z-axis parameter region as from,to,step"},
    'overall-angle': {
        'default': None,
        'type': float,
        'help': "The total angle over which projections were taken in degrees"},
    'lamino-angle': {
        'default': None,
        'required': True,
        'type': float,
        'help': "The laminographic angle in degrees"},
    'roll-angle': {
        'default': 0.0,
        'type': float,
        # Implicit concatenation with an explicit separating space; the former
        # backslash-newline inside the literal either glued the words
        # ("meanclockwise") or embedded the next line's indentation.
        'help': "Sample angular misalignment to the side (roll) in degrees, positive angles mean "
                "clockwise misalignment"},
    'slices-per-device': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Number of slices computed by one computing device"},
    'only-bp': {
        'default': False,
        'action': 'store_true',
        'help': "Do only backprojection with no other processing steps"},
    'lamino-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'clamp',
        'help': "Padded values assignment for the filtered projection"}}
# Filtered backprojection options (TOMO_PARAMS).
SECTIONS['fbp'] = {
    'crop-width': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Width of final slice"},
    'projection-crop-after': {
        'choices': ['filter', 'backprojection'],
        'default': 'backprojection',
        'help': "Whether to crop projections after filtering (can cause truncation "
                "artifacts) or after backprojection"}}
# Direct Fourier inversion options (TOMO_PARAMS).
SECTIONS['dfi'] = {
    'oversampling': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Oversample factor"}}
# Options common to the iterative reconstruction methods (TOMO_PARAMS).
SECTIONS['ir'] = {
    'num-iterations': {
        'default': 10,
        'type': restrict_value((0, None), dtype=int),
        'help': "Maximum number of iterations"}}
# SART-specific options (TOMO_PARAMS).
SECTIONS['sart'] = {
    'relaxation-factor': {
        'default': 0.25,
        'type': float,
        'help': "Relaxation factor"}}
# SBTV-specific options (TOMO_PARAMS).
SECTIONS['sbtv'] = {
    'lambda': {
        'default': 0.1,
        'type': float,
        'help': "Lambda"},
    'mu': {
        'default': 0.5,
        'type': float,
        'help': "mu"}}
# Settings used by the GUI (display toggles, remembered paths, FFC options).
SECTIONS['gui'] = {
    'enable-cropping': {
        'default': False,
        'help': "Enable cropping width",
        'action': 'store_true'},
    'show-2d': {
        'default': False,
        'help': "Show 2D slices with pyqtgraph",
        'action': 'store_true'},
    'show-3d': {
        'default': False,
        'help': "Show 3D slices with pyqtgraph",
        'action': 'store_true'},
    'last-dir': {
        'default': '.',
        'type': str,
        'help': "Path of the last used directory",
        'metavar': 'PATH'},
    'deg0': {
        'default': '.',
        'type': str,
        'help': "Location with 0 deg projection",
        'metavar': 'PATH'},
    'deg180': {
        'default': '.',
        'type': str,
        'help': "Location with 180 deg projection",
        'metavar': 'PATH'},
    'ffc-correction': {
        'default': False,
        'help': "Enable darks or flats correction",
        'action': 'store_true'},
    'num-flats': {
        'default': 0,
        'type': int,
        'help': "Number of flats for ffc correction."}}
# Rotation axis estimation options.
SECTIONS['estimate'] = {
    'estimate-method': {
        'type': str,
        'default': 'correlation',
        'help': 'Rotation axis estimation algorithm',
        'choices': ['reconstruction', 'correlation']}}
# Performance measurement options; the *-range values are parsed by
# range_list (a single value or a range, per the help texts).
SECTIONS['perf'] = {
    'num-runs': {
        'default': 3,
        'type': restrict_value((0, None), dtype=int),
        'help': "Number of runs"},
    'width-range': {
        'default': '1024',
        'type': range_list,
        'help': "Width or range of widths of generated projections"},
    'height-range': {
        'default': '1024',
        'type': range_list,
        'help': "Height or range of heights of generated projections"},
    'num-projection-range': {
        'default': '512',
        'type': range_list,
        'help': "Number or range of number of projections"}}
# Projection preprocessing options (PREPROC_PARAMS). The projection-* options
# share names with 'reconstruction' and 'fbp'; those sections are never
# combined with this one, which would make argparse reject duplicate options.
SECTIONS['preprocess'] = {
    'transpose-input': {
        'default': False,
        'action': 'store_true',
        'help': "Transpose projections before they are backprojected (after phase retrieval)"},
    'projection-filter': {
        'default': 'ramp-fromreal',
        'type': str,
        'help': "Projection filter",
        'choices': ['none', 'ramp', 'ramp-fromreal', 'butterworth', 'faris-byer', 'bh3', 'hamming']},
    'projection-filter-cutoff': {
        'default': 0.5,
        'type': float,
        'help': "Relative cutoff frequency"},
    'projection-filter-scale': {
        'default': 1.,
        'type': float,
        'help': "Multiplicative factor of the projection filter"},
    'projection-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'clamp_to_edge',
        'help': "Padded values assignment"},
    'projection-crop-after': {
        'choices': ['filter', 'backprojection'],
        'default': 'backprojection',
        'help': "Whether to crop projections after filtering (can cause truncation "
                "artifacts) or after backprojection"}}
# Geometry needed for cone beam weighting (PREPROC_PARAMS).
SECTIONS['cone-beam-weight'] = {
    'source-position-y': {
        'default': "-Inf",
        'type': tupleize(dtype=list),
        'help': "Y source position (along beam direction) in global coordinates [pixels]"},
    'detector-position-y': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Y detector position (along beam direction) in global coordinates [pixels]"},
    'center-position-x': {
        'default': None,
        'type': tupleize(),
        'help': "X rotation axis position on a projection"},
    'center-position-z': {
        'default': None,
        'type': tupleize(),
        'help': "Z rotation axis position on a projection"},
    'axis-angle-x': {
        'default': "0",
        'type': tupleize(dtype=list),
        # Trailing space added: the help used to render as "x axis(laminographic".
        'help': "Rotation axis rotation around the x axis "
                "(laminographic angle, 0 = tomography) [deg]"}}
# Options of the general backprojection (parallel/cone beam,
# tomography/laminography) used by GEN_RECO_PARAMS.
SECTIONS['general-reconstruction'] = {
    'enable-tracing': {
        'default': False,
        'help': "Enable tracing and store result in .PID.json",
        'action': 'store_true'},
    'disable-cone-beam-weight': {
        'default': False,
        'action': 'store_true',
        'help': "Disable cone beam weighting"},
    'slice-memory-coeff': {
        'default': 0.8,
        # Keep the range stated in the help in sync with the bounds given to
        # restrict_value; the help previously said 0.9 instead of 0.95.
        'type': restrict_value((0.01, 0.95)),
        'help': "Portion of the GPU memory used for slices (from 0.01 to 0.95) [fraction]. "
                "The total amount of consumed memory will be larger depending on the "
                "complexity of the graph. In case of OpenCL memory allocation errors, "
                "try reducing this value."},
    'num-gpu-threads': {
        'default': 1,
        'type': restrict_value((1, None), dtype=int),
        'help': "Number of parallel reconstruction threads on one GPU"},
    'disable-projection-crop': {
        'default': False,
        'action': 'store_true',
        'help': "Disable automatic cropping of projections computed from volume region"},
    'dry-run': {
        'default': False,
        'help': "Reconstruct without reading or writing data",
        'action': 'store_true'},
    'data-splitting-policy': {
        'default': 'one',
        'type': str,
        'help': "'one': one GPU should process as many slices as possible, "
                "'many': slices should be spread across as many GPUs as possible",
        'choices': ['one', 'many']},
    'projection-margin': {
        'default': 0,
        'type': restrict_value((0, None), dtype=int),
        'help': "By optimization of the read projection region, the read region will be "
                "[y - margin, y + height + margin]"},
    'slices-per-device': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Number of slices computed by one computing device"},
    'gpus': {
        'default': None,
        'nargs': '+',
        'type': int,
        'help': "GPUs with these indices will be used (0-based)"},
    'burst': {
        'default': None,
        'type': restrict_value((0, None), dtype=int),
        'help': "Number of projections processed per kernel invocation"},
    'x-region': {
        'default': "0,-1,1",
        'type': tupleize(num_items=3),
        'help': "x region as from,to,step"},
    'y-region': {
        'default': "0,-1,1",
        'type': tupleize(num_items=3),
        'help': "y region as from,to,step"},
    'z': {
        'default': 0,
        'type': int,
        'help': "z coordinate of the reconstructed slice"},
    'z-parameter': {
        'default': 'z',
        'type': str,
        'choices': ['axis-angle-x', 'axis-angle-y', 'axis-angle-z',
                    'volume-angle-x', 'volume-angle-y', 'volume-angle-z',
                    'detector-angle-x', 'detector-angle-y', 'detector-angle-z',
                    'detector-position-x', 'detector-position-y', 'detector-position-z',
                    'source-position-x', 'source-position-y', 'source-position-z',
                    'center-position-x', 'center-position-z', 'z'],
        'help': "Parameter to vary along the reconstructed z-axis"},
    'region': {
        'default': "0,1,1",
        'type': tupleize(num_items=3),
        'help': "z axis parameter region as from,to,step"},
    'source-position-x': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "X source position (horizontal) in global coordinates [pixels]"},
    'source-position-z': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Z source position (vertical) in global coordinates [pixels]"},
    'detector-position-x': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "X detector position (horizontal) in global coordinates [pixels]"},
    'detector-position-z': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Z detector position (vertical) in global coordinates [pixels]"},
    'detector-angle-x': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Detector rotation around the x axis (horizontal) [deg]"},
    'detector-angle-y': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Detector rotation around the y axis (along beam direction) [deg]"},
    'detector-angle-z': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Detector rotation around the z axis (vertical) [deg]"},
    'axis-angle-y': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Rotation axis rotation around the y axis (along beam direction) [deg]"},
    'axis-angle-z': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Rotation axis rotation around the z axis (vertical) [deg]"},
    'volume-angle-x': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Volume rotation around the x axis (horizontal) [deg]"},
    'volume-angle-y': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Volume rotation around the y axis (along beam direction) [deg]"},
    'volume-angle-z': {
        'default': "0",
        'type': tupleize(dtype=list),
        'help': "Volume rotation around the z axis (vertical) [deg]"},
    'compute-type': {
        'default': 'float',
        'type': str,
        'help': "Data type for performing kernel math operations",
        'choices': ['half', 'float', 'double']},
    'result-type': {
        'default': 'float',
        'type': str,
        'help': "Data type for storing the intermediate gray value for a voxel "
                "from various rotation angles",
        'choices': ['half', 'float', 'double']},
    'store-type': {
        'default': 'float',
        'type': str,
        'help': "Data type of the output volume",
        'choices': ['half', 'float', 'double', 'uchar', 'ushort', 'uint']},
    'overall-angle': {
        'default': None,
        'type': float,
        'help': "The total angle over which projections were taken in degrees"},
    'genreco-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'clamp',
        'help': "Padded values assignment for the filtered projection"},
    'slice-gray-map': {
        'default': "0,0",
        'type': tupleize(num_items=2, conv=float),
        'help': "Minimum and maximum gray value mapping if store-type is integer-based"}}
# Options for detecting large spots in input images.
SECTIONS['find-large-spots'] = {
    'images': {
        'default': None,
        'type': str,
        'help': "Location with input images",
        'metavar': 'PATH'},
    'gauss-sigma': {
        'default': 0.0,
        'type': float,
        'help': "Gaussian sigma for removing low frequencies (filter will be 1 - gauss window)"},
    'blurred-output': {
        'default': None,
        'type': str,
        'help': "Path where to store the blurred input"},
    'spot-threshold': {
        'default': 0.0,
        'type': float,
        'help': "Pixels with grey value larger than this are considered as spots"},
    'spot-threshold-mode': {
        'default': 'absolute',
        'type': str,
        # Implicit concatenation instead of a backslash-newline inside the
        # literal, which would embed the next line's indentation in the text.
        'help': "Pixels must be either \"below\", \"above\" the spot threshold, or "
                "their \"absolute\" value can be compared",
        'choices': ['below', 'above', 'absolute']},
    'grow-threshold': {
        'default': 0.0,
        'type': float,
        'help': "Spot growing threshold, if 0 it will be set to FWTM times noise standard deviation"},
    'find-large-spots-padding-mode': {
        'choices': ['none', 'clamp', 'clamp_to_edge', 'repeat', 'mirrored_repeat'],
        'default': 'repeat',
        'help': "Padded values assignment for the filtered input image"},
}
# Section tuples passed to Params for the individual commands. Params appends
# 'general' and 'reading' itself. Sections within one tuple must not define the
# same option name twice, because add_parser_args registers every option of
# every section with the same argparse parser.
TOMO_PARAMS = ('flat-correction', 'reconstruction', 'tomographic-reconstruction', 'fbp', 'dfi', 'ir', 'sart', 'sbtv')

PREPROC_PARAMS = ('preprocess', 'cone-beam-weight', 'flat-correction', 'retrieve-phase')
LAMINO_PARAMS = PREPROC_PARAMS + ('laminographic-reconstruction',)
GEN_RECO_PARAMS = PREPROC_PARAMS + ('general-reconstruction',)
# Human-readable section titles: one entry per SECTIONS key, in the same order
# (20 entries each). NOTE(review): 'General reconstruction' is used for both
# 'reconstruction' and 'general-reconstruction'; confirm this duplication is
# intended by whatever consumes these titles.
NICE_NAMES = ('General', 'Input', 'Flat field correction', 'Phase retrieval',
              'Sinogram generation', 'General reconstruction', 'Tomographic reconstruction',
              'Laminographic reconstruction', 'Filtered backprojection',
              'Direct Fourier Inversion', 'Iterative reconstruction',
              'SART', 'SBTV', 'GUI settings', 'Estimation', 'Performance',
              'Preprocess', 'Cone beam weight', 'General reconstruction', 'Find large spots')
def get_config_name():
    """
    Get the value of the command line --config option.

    Both the ``--config FILE`` and the ``--config=FILE`` forms are recognized;
    the first occurrence in :data:`sys.argv` wins. An empty string is returned
    when the option is absent or has no value (``--config`` as the last
    argument, or ``--config=``). In that case reading the config is simply
    skipped and argparse is left to report any usage error, instead of this
    function crashing with an IndexError.
    """
    prefix = '--config'
    for i, arg in enumerate(sys.argv):
        if not arg.startswith(prefix):
            continue
        if arg == prefix:
            # Separate-token form: the value is the next argument, if there is one.
            return sys.argv[i + 1] if i + 1 < len(sys.argv) else ''
        # Attached form. Slice off the prefix rather than split on it, so a
        # value which itself contains '--config' is kept intact.
        name = arg[len(prefix):]
        if name.startswith('='):
            name = name[1:]
        return name
    return ''
def parse_known_args(parser, subparser=False):
    """
    Parse arguments from file and then override by the ones specified on the
    command line. Use *parser* for parsing and if *subparser* is True take into
    account that there is a value on the command line specifying the subparser.

    Only the namespace part of ``parser.parse_known_args`` is returned; any
    unrecognised arguments are dropped.
    """
    if len(sys.argv) > 1:
        # With subparsers, argv[1] names the subcommand and has to precede the
        # options read from the config file.
        subparser_value = [sys.argv[1]] if subparser else []
        config_values = config_to_list(config_name=get_config_name())
        # Config-file values come before the command line so that values given
        # on the command line are parsed later and override them.
        # NOTE(review): when *subparser* is True, argv[1] occurs twice in
        # *values* (in subparser_value and again in sys.argv[1:]); confirm the
        # parser tolerates this as intended.
        values = subparser_value + config_values + sys.argv[1:]
        args = None
        if config_values:
            args = parser.parse_known_args(args=subparser_value + config_values)[0]
            # NOTE(review): the namespace populated here is not used for the
            # value returned below; effectively this call only validates the
            # command line (parse_args exits via parser.error on unknown
            # arguments).
            parser.parse_args(args=sys.argv[1:], namespace=args)
    else:
        # Only the program name is present: parse an empty argument list.
        values = ""
    return parser.parse_known_args(values)[0]
def config_to_list(config_name=''):
    """
    Turn the options stored in a config file into an argv-style list.

    *config_name* is the path of the config file; an unreadable or missing
    file yields an empty list. Only options that are declared in SECTIONS and
    present in the file are emitted, and empty or 'None' values are skipped:

    * ``store_true`` options become a bare ``--name`` when the value is 'True';
    * ``nargs='+'`` options become ``--name`` followed by the comma-separated,
      stripped items;
    * everything else becomes ``--name=value``.
    """
    parser = configparser.ConfigParser()
    if not parser.read([config_name]):
        return []

    argv = []
    for section in SECTIONS:
        declared = SECTIONS[section]
        for name, opts in list(declared.items()):
            if not parser.has_option(section, name):
                continue
            raw = parser.get(section, name)
            if raw in ('', 'None'):
                continue
            flag = '--{}'.format(name)
            if opts.get('action', None) == 'store_true':
                # Flags take no value on the command line.
                if raw == 'True':
                    argv.append(flag)
            elif opts.get('nargs', None) == '+':
                argv.append(flag)
                argv.extend(item.strip() for item in raw.split(','))
            else:
                argv.append('--{}={}'.format(name, raw))
    return argv
class Params(object):
    """Build argparse options from a chosen subset of the SECTIONS table.

    The 'general' and 'reading' sections are always appended to the sections
    passed to the constructor.
    """

    def __init__(self, sections=()):
        # Accept any iterable of section names (tuple, list, ...) or a single
        # name; previously only tuples could be concatenated here.
        if isinstance(sections, str):
            sections = (sections,)
        self.sections = tuple(sections) + ('general', 'reading')

    def add_parser_args(self, parser):
        """Add a ``--name`` option to *parser* for every entry of the selected
        sections, alphabetically within each section."""
        for section in self.sections:
            options = SECTIONS[section]
            for name in sorted(options):
                parser.add_argument('--{}'.format(name), **options[name])

    def add_arguments(self, parser):
        """Add all options to *parser* and return it."""
        self.add_parser_args(parser)
        return parser

    def get_defaults(self):
        """Return a namespace holding the default value of every option."""
        parser = argparse.ArgumentParser()
        self.add_arguments(parser)
        return parser.parse_args('')
def write(config_file, args=None, sections=None):
    """
    Write *config_file* with values from *args* if they are specified,
    otherwise use the defaults. If *sections* are specified, write values from
    *args* only to those sections, use the defaults on the remaining ones.

    List values are joined with ', '. Empty or None values are written
    commented out (option name prefixed with '# '), so they are ignored when
    the file is read back. The 'config' option itself is never written.
    """
    config = configparser.ConfigParser()
    for section in SECTIONS:
        config.add_section(section)
        for name, opts in list(SECTIONS[section].items()):
            if name == 'config':
                continue
            attr = name.replace('-', '_')
            if args and sections and section in sections and hasattr(args, attr):
                value = getattr(args, attr)
                if isinstance(value, list):
                    value = ', '.join(str(item) for item in value)
            else:
                value = opts.get('default')
            # Python 3 ConfigParser.set() only accepts string values.
            value = '' if value is None else str(value)
            prefix = '# ' if value == '' else ''
            # The default BasicInterpolation rejects a bare '%' in set() and
            # turns '%%' back into '%' on get(), so escape literal percent
            # signs (e.g. printf-style file name patterns).
            config.set(section, prefix + name, value.replace('%', '%%'))
    # configparser writes str, so the file must be opened in text mode.
    with open(config_file, 'w') as f:
        config.write(f)
def log_values(args):
    """Log every value stored in the *args* namespace at DEBUG level.

    Values are grouped under the title of the SECTIONS entry that declares
    them (titles come from NICE_NAMES) and listed alphabetically; None is
    shown as '-'. Sections without any matching value are not logged. DEBUG
    output requires --verbose.
    """
    stored = args.__dict__
    for section, title in zip(SECTIONS, NICE_NAMES):
        declared = SECTIONS[section]
        keys = sorted(key for key in list(stored.keys()) if key.replace('_', '-') in declared)
        if not keys:
            continue
        LOG.debug(title)
        for key in keys:
            shown = "-" if stored[key] is None else stored[key]
            LOG.debug(" {:<16} {}".format(key, shown))
tofu-0.12.0/tofu/ez/ 0000775 0000000 0000000 00000000000 14237137211 0014143 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/ 0000775 0000000 0000000 00000000000 14237137211 0014567 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/Advanced/ 0000775 0000000 0000000 00000000000 14237137211 0016274 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/Advanced/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0020373 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/Advanced/advanced.py 0000664 0000000 0000000 00000014120 14237137211 0020411 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class AdvancedGroup(QGroupBox):
    """
    Group box with the advanced TOFU reconstruction settings: the extended
    (laminographic) geometry parameters and the auxiliary flat-field-correction
    scales. Every edit is copied into ``parameters.params``.
    """

    def __init__(self):
        super().__init__()

        self.setTitle("Advanced TOFU Reconstruction Settings")
        self.setStyleSheet("QGroupBox {color: green;}")

        # Extended (laminographic) reconstruction settings
        self.lamino_group = QGroupBox("Extended Settings of Reconstruction Algorithms")
        self.lamino_group.clicked.connect(self.set_lamino_group)

        self.lamino_angle_label, self.lamino_angle_entry = self._labelled_entry(
            "Laminographic angle ", self.set_lamino_angle)
        self.overall_rotation_label, self.overall_rotation_entry = self._labelled_entry(
            "Overall rotation range about CT Z-axis", self.set_overall_rotation)
        self.center_position_z_label, self.center_position_z_entry = self._labelled_entry(
            "Center Position Z ", self.set_center_position_z)
        self.axis_rotation_y_label, self.axis_rotation_y_entry = self._labelled_entry(
            "Sample rotation about the beam Y-axis ", self.set_rotation_about_beam)

        # Auxiliary flat-field correction
        self.dark_scale_label, self.dark_scale_entry = self._labelled_entry(
            "Dark scale ", self.set_dark_scale)
        self.flat_scale_label, self.flat_scale_entry = self._labelled_entry(
            "Flat scale ", self.set_flat_scale)

        self.set_layout()

    @staticmethod
    def _labelled_entry(text, slot):
        """Create a label and a line edit whose editingFinished calls *slot*."""
        label = QLabel(text)
        entry = QLineEdit()
        entry.editingFinished.connect(slot)
        return label, entry

    @staticmethod
    def _grid_of(rows):
        """Return a grid layout with one (label, entry) pair per row."""
        grid = QGridLayout()
        for row, (label, entry) in enumerate(rows):
            grid.addWidget(label, row, 0)
            grid.addWidget(entry, row, 1)
        return grid

    def set_layout(self):
        # Both sub-groups are checkable and start unchecked.
        self.lamino_group.setCheckable(True)
        self.lamino_group.setChecked(False)
        self.lamino_group.setLayout(self._grid_of((
            (self.lamino_angle_label, self.lamino_angle_entry),
            (self.overall_rotation_label, self.overall_rotation_entry),
            (self.center_position_z_label, self.center_position_z_entry),
            (self.axis_rotation_y_label, self.axis_rotation_y_entry),
        )))

        aux_group = QGroupBox("Auxiliary FFC Settings")
        aux_group.setCheckable(True)
        aux_group.setChecked(False)
        aux_group.setLayout(self._grid_of((
            (self.dark_scale_label, self.dark_scale_entry),
            (self.flat_scale_label, self.flat_scale_entry),
        )))

        outer = QGridLayout()
        outer.addWidget(self.lamino_group)
        outer.addWidget(aux_group)
        self.setLayout(outer)

    def init_values(self):
        """Reset the widgets and the matching params to their defaults."""
        self.lamino_group.setChecked(False)
        parameters.params['advanced_advtofu_extended_settings'] = False
        defaults = (
            (self.lamino_angle_entry, 'advanced_advtofu_lamino_angle', 30),
            (self.overall_rotation_entry, 'advanced_adv_tofu_z_axis_rotation', 360),
            (self.center_position_z_entry, 'advanced_advtofu_center_position_z', ""),
            (self.axis_rotation_y_entry, 'advanced_advtofu_y_axis_rotation', ""),
            (self.dark_scale_entry, 'advanced_advtofu_aux_ffc_dark_scale', ""),
            (self.flat_scale_entry, 'advanced_advtofu_aux_ffc_flat_scale', ""),
        )
        for entry, key, value in defaults:
            entry.setText(str(value))
            parameters.params[key] = value

    def set_values_from_params(self):
        """Show the values currently stored in ``parameters.params``."""
        self.lamino_group.setChecked(parameters.params['advanced_advtofu_extended_settings'])
        shown = (
            (self.lamino_angle_entry, 'advanced_advtofu_lamino_angle'),
            (self.overall_rotation_entry, 'advanced_adv_tofu_z_axis_rotation'),
            (self.center_position_z_entry, 'advanced_advtofu_center_position_z'),
            (self.axis_rotation_y_entry, 'advanced_advtofu_y_axis_rotation'),
            (self.dark_scale_entry, 'advanced_advtofu_aux_ffc_dark_scale'),
            (self.flat_scale_entry, 'advanced_advtofu_aux_ffc_flat_scale'),
        )
        for entry, key in shown:
            entry.setText(str(parameters.params[key]))

    @staticmethod
    def _store_entry(entry, key):
        """Log the text of *entry* and store it as a string under *key*."""
        text = entry.text()
        LOG.debug(text)
        parameters.params[key] = str(text)

    def set_lamino_group(self):
        checked = self.lamino_group.isChecked()
        LOG.debug("Lamino: " + str(checked))
        parameters.params['advanced_advtofu_extended_settings'] = bool(checked)

    def set_lamino_angle(self):
        self._store_entry(self.lamino_angle_entry, 'advanced_advtofu_lamino_angle')

    def set_overall_rotation(self):
        self._store_entry(self.overall_rotation_entry, 'advanced_adv_tofu_z_axis_rotation')

    def set_center_position_z(self):
        self._store_entry(self.center_position_z_entry, 'advanced_advtofu_center_position_z')

    def set_rotation_about_beam(self):
        self._store_entry(self.axis_rotation_y_entry, 'advanced_advtofu_y_axis_rotation')

    def set_dark_scale(self):
        self._store_entry(self.dark_scale_entry, 'advanced_advtofu_aux_ffc_dark_scale')

    def set_flat_scale(self):
        self._store_entry(self.flat_scale_entry, 'advanced_advtofu_aux_ffc_flat_scale')
tofu-0.12.0/tofu/ez/GUI/Advanced/ffc.py 0000664 0000000 0000000 00000012376 14237137211 0017415 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import (
QGridLayout,
QLabel,
QGroupBox,
QLineEdit,
QCheckBox,
QRadioButton,
QHBoxLayout,
)
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class FFCGroup(QGroupBox):
    """
    Flat Field Correction Settings

    Widget edits are copied into ``parameters.params`` under the
    ``advanced_ffc_*`` keys.
    """

    def __init__(self):
        super().__init__()

        self.setTitle("Flat Field Correction")
        self.setStyleSheet("QGroupBox {color: indigo;}")

        self.method_label = QLabel("Method:")

        self.average_rButton = QRadioButton("Average")
        self.average_rButton.clicked.connect(self.set_method)

        self.ssim_rButton = QRadioButton("SSIM")
        self.ssim_rButton.clicked.connect(self.set_method)

        self.eigen_rButton = QRadioButton("Eigen")
        self.eigen_rButton.clicked.connect(self.set_method)

        self.enable_sinFFC_checkbox = QCheckBox(
            "Use Smart Intensity Normalization Flat Field Correction"
        )
        self.enable_sinFFC_checkbox.stateChanged.connect(self.set_sinFFC)

        self.eigen_pco_repetitions_label = QLabel("Eigen PCO Repetitions")
        self.eigen_pco_repetitions_entry = QLineEdit()
        self.eigen_pco_repetitions_entry.editingFinished.connect(self.set_pcoReps)

        self.eigen_pco_downsample_label = QLabel("Eigen PCO Downsample")
        self.eigen_pco_downsample_entry = QLineEdit()
        self.eigen_pco_downsample_entry.editingFinished.connect(self.set_pcoDowns)

        self.downsample_label = QLabel("Downsample")
        self.downsample_entry = QLineEdit()
        self.downsample_entry.editingFinished.connect(self.set_downsample)

        self.set_layout()

    def set_layout(self):
        layout = QGridLayout()

        rbutton_layout = QHBoxLayout()
        rbutton_layout.addWidget(self.method_label)
        rbutton_layout.addWidget(self.eigen_rButton)
        rbutton_layout.addWidget(self.average_rButton)
        rbutton_layout.addWidget(self.ssim_rButton)

        layout.addWidget(self.enable_sinFFC_checkbox, 0, 0)
        layout.addItem(rbutton_layout, 1, 0, 1, 2)
        layout.addWidget(self.eigen_pco_repetitions_label, 2, 0)
        layout.addWidget(self.eigen_pco_repetitions_entry, 2, 1)
        layout.addWidget(self.eigen_pco_downsample_label, 3, 0)
        layout.addWidget(self.eigen_pco_downsample_entry, 3, 1)
        layout.addWidget(self.downsample_label, 4, 0)
        layout.addWidget(self.downsample_entry, 4, 1)

        self.setLayout(layout)

    def init_values(self):
        """Reset the widgets to their defaults and store the same values in params.

        Programmatic setChecked/setText calls do not emit the clicked /
        editingFinished signals the setters are connected to, so the params
        are written here explicitly (stored as strings, like the setters do).
        """
        self.eigen_rButton.setChecked(True)
        self.average_rButton.setChecked(False)
        self.ssim_rButton.setChecked(False)
        parameters.params['advanced_ffc_method'] = "eigen"
        self.enable_sinFFC_checkbox.setChecked(False)
        parameters.params['advanced_ffc_sinFFC'] = False
        self.eigen_pco_repetitions_entry.setText("4")
        parameters.params['advanced_ffc_eigen_pco_reps'] = "4"
        self.eigen_pco_downsample_entry.setText("2")
        parameters.params['advanced_ffc_eigen_pco_downsample'] = "2"
        self.downsample_entry.setText("4")
        parameters.params['advanced_ffc_downsample'] = "4"

    def set_values_from_params(self):
        """Show the values currently stored in ``parameters.params``."""
        self.enable_sinFFC_checkbox.setChecked(parameters.params['advanced_ffc_sinFFC'])
        self.set_method_from_params()
        self.eigen_pco_repetitions_entry.setText(str(parameters.params['advanced_ffc_eigen_pco_reps']))
        self.eigen_pco_downsample_entry.setText(str(parameters.params['advanced_ffc_eigen_pco_downsample']))
        self.downsample_entry.setText(str(parameters.params['advanced_ffc_downsample']))

    def set_sinFFC(self):
        LOG.debug("sinFFC: " + str(self.enable_sinFFC_checkbox.isChecked()))
        parameters.params['advanced_ffc_sinFFC'] = bool(self.enable_sinFFC_checkbox.isChecked())

    def set_pcoReps(self):
        LOG.debug("PCO Reps: " + str(self.eigen_pco_repetitions_entry.text()))
        parameters.params['advanced_ffc_eigen_pco_reps'] = str(self.eigen_pco_repetitions_entry.text())

    def set_pcoDowns(self):
        LOG.debug("PCO Downsample: " + str(self.eigen_pco_downsample_entry.text()))
        parameters.params['advanced_ffc_eigen_pco_downsample'] = str(self.eigen_pco_downsample_entry.text())

    def set_downsample(self):
        LOG.debug("Downsample: " + str(self.downsample_entry.text()))
        parameters.params['advanced_ffc_downsample'] = str(self.downsample_entry.text())

    def set_method(self):
        """Store the selected method as "eigen", "average" or "ssim"."""
        if self.eigen_rButton.isChecked():
            LOG.debug("Method: Eigen")
            parameters.params['advanced_ffc_method'] = "eigen"
        elif self.average_rButton.isChecked():
            LOG.debug("Method: Average")
            parameters.params['advanced_ffc_method'] = "average"
        elif self.ssim_rButton.isChecked():
            LOG.debug("Method: SSIM")
            parameters.params['advanced_ffc_method'] = "ssim"

    def set_method_from_params(self):
        """Check the radio button that matches ``params['advanced_ffc_method']``.

        set_method and init_values store the strings "eigen", "average" and
        "ssim", so those are matched (case-insensitively). The integer codes
        1/2/3 that this method used to compare against are still accepted.
        Unknown values leave the buttons unchanged.
        """
        method = parameters.params['advanced_ffc_method']
        legacy_codes = {1: "eigen", 2: "average", 3: "ssim"}
        method = legacy_codes.get(method, method)
        if isinstance(method, str):
            method = method.lower()
        buttons = (
            ("eigen", self.eigen_rButton),
            ("average", self.average_rButton),
            ("ssim", self.ssim_rButton),
        )
        if method not in {key for key, _ in buttons}:
            LOG.debug("Unknown FFC method in params: " + str(method))
            return
        for key, button in buttons:
            button.setChecked(key == method)
tofu-0.12.0/tofu/ez/GUI/Advanced/nlmdn.py 0000664 0000000 0000000 00000040312 14237137211 0017756 0 ustar 00root root 0000000 0000000 import logging
import os
from shutil import rmtree
from PyQt5.QtWidgets import (
QGridLayout,
QLabel,
QGroupBox,
QLineEdit,
QCheckBox,
QPushButton,
QFileDialog,
QMessageBox,
)
from PyQt5.QtCore import Qt
import tofu.ez.params as parameters
from tofu.ez.main_nlm import main_tk
LOG = logging.getLogger(__name__)
class NLMDNGroup(QGroupBox):
    """
    Non-local means de-noising settings

    Widget edits are copied into ``parameters.params`` under the
    ``advanced_nlmdn_*`` keys. "Apply filter" and "Dry run" build a tk_args
    from those params and pass it to main_tk.
    """

    def __init__(self):
        super().__init__()

        self.setTitle("Non-local-means Denoising")
        self.setStyleSheet("QGroupBox {color: royalblue;}")

        self.apply_to_reco_checkbox = QCheckBox("Automatically apply NLMDN to reconstructed slices")
        self.apply_to_reco_checkbox.stateChanged.connect(self.set_apply_to_reco)

        self.input_dir_button = QPushButton("Select input directory")
        self.input_dir_button.clicked.connect(self.set_indir_button)

        self.select_img_button = QPushButton("Select one image")
        self.select_img_button.clicked.connect(self.select_image)

        self.input_dir_entry = QLineEdit()
        self.input_dir_entry.editingFinished.connect(self.set_indir_entry)

        self.output_dir_button = QPushButton("Select output directory or filename pattern")
        self.output_dir_button.clicked.connect(self.set_outdir_button)

        self.save_bigtif_checkbox = QCheckBox("Save in bigtif container")
        self.save_bigtif_checkbox.clicked.connect(self.set_save_bigtif)

        self.output_dir_entry = QLineEdit()
        self.output_dir_entry.editingFinished.connect(self.set_outdir_entry)

        self.similarity_radius_label = QLabel("Radius for similarity search")
        self.similarity_radius_entry = QLineEdit()
        self.similarity_radius_entry.editingFinished.connect(self.set_rad_sim_entry)

        self.patch_radius_label = QLabel("Radius of patches")
        self.patch_radius_entry = QLineEdit()
        self.patch_radius_entry.editingFinished.connect(self.set_rad_patch_entry)

        self.smoothing_label = QLabel("Smoothing control parameter")
        self.smoothing_entry = QLineEdit()
        self.smoothing_entry.editingFinished.connect(self.set_smoothing_entry)

        self.noise_std_label = QLabel("Noise standard deviation")
        self.noise_std_entry = QLineEdit()
        self.noise_std_entry.editingFinished.connect(self.set_noise_entry)

        self.window_label = QLabel("Window (optional)")
        self.window_entry = QLineEdit()
        self.window_entry.editingFinished.connect(self.set_window_entry)

        self.fast_checkbox = QCheckBox("Fast")
        self.fast_checkbox.clicked.connect(self.set_fast_checkbox)

        self.sigma_checkbox = QCheckBox("Estimate sigma")
        self.sigma_checkbox.clicked.connect(self.set_sigma_checkbox)

        self.help_button = QPushButton("Help")
        self.help_button.clicked.connect(self.help_button_pressed)

        self.delete_button = QPushButton("Delete reco dir")
        self.delete_button.clicked.connect(self.delete_button_pressed)

        self.dry_button = QPushButton("Dry run")
        self.dry_button.clicked.connect(self.dry_button_pressed)

        self.apply_button = QPushButton("Apply filter")
        self.apply_button.clicked.connect(self.apply_button_pressed)

        self.set_layout()

    def set_layout(self):
        layout = QGridLayout()

        layout.addWidget(self.apply_to_reco_checkbox, 0, 0, 1, 1)
        layout.addWidget(self.input_dir_button, 1, 0, 1, 2)
        layout.addWidget(self.select_img_button, 1, 2, 1, 2)
        layout.addWidget(self.input_dir_entry, 2, 0, 1, 4)
        layout.addWidget(self.output_dir_button, 3, 0, 1, 2)
        layout.addWidget(self.save_bigtif_checkbox, 3, 2, 1, 2, Qt.AlignCenter)
        layout.addWidget(self.output_dir_entry, 4, 0, 1, 4)
        layout.addWidget(self.similarity_radius_label, 5, 0, 1, 2)
        layout.addWidget(self.similarity_radius_entry, 5, 2, 1, 2)
        layout.addWidget(self.patch_radius_label, 6, 0, 1, 2)
        layout.addWidget(self.patch_radius_entry, 6, 2, 1, 2)
        layout.addWidget(self.smoothing_label, 7, 0, 1, 2)
        layout.addWidget(self.smoothing_entry, 7, 2, 1, 2)
        layout.addWidget(self.noise_std_label, 8, 0, 1, 2)
        layout.addWidget(self.noise_std_entry, 8, 2, 1, 2)
        layout.addWidget(self.window_label, 9, 0, 1, 2)
        layout.addWidget(self.window_entry, 9, 2, 1, 2)
        layout.addWidget(self.fast_checkbox, 10, 0, 1, 2, Qt.AlignCenter)
        layout.addWidget(self.sigma_checkbox, 10, 2, 1, 2, Qt.AlignCenter)
        layout.addWidget(self.help_button, 11, 0, 1, 1)
        layout.addWidget(self.delete_button, 11, 1)
        layout.addWidget(self.dry_button, 11, 2)
        layout.addWidget(self.apply_button, 11, 3)

        self.setLayout(layout)

    def init_values(self):
        """Reset all widgets to their defaults and store the same values in params.

        Programmatic setText/setChecked calls do not emit the editingFinished /
        clicked signals the setters are connected to, so every param is
        written here explicitly to keep the GUI and params in sync.
        """
        self.apply_to_reco_checkbox.setChecked(False)
        parameters.params['advanced_nlmdn_apply_after_reco'] = False
        cwd = os.getcwd()
        self.input_dir_entry.setText(cwd)
        parameters.params['advanced_nlmdn_input_dir'] = cwd
        # The default input is a directory, not a single image.
        parameters.params['advanced_nlmdn_input_is_file'] = False
        self.output_dir_entry.setText(cwd + '-nlmfilt')
        parameters.params['advanced_nlmdn_output_dir'] = cwd + '-nlmfilt'
        self.save_bigtif_checkbox.setChecked(False)
        self.e_bigtif = False  # legacy attribute, kept for compatibility
        parameters.params['advanced_nlmdn_save_bigtiff'] = False
        # Numeric entries are stored as strings, like the setters do.
        for entry, key, text in (
            (self.similarity_radius_entry, 'advanced_nlmdn_sim_search_radius', "10"),
            (self.patch_radius_entry, 'advanced_nlmdn_patch_radius', "3"),
            (self.smoothing_entry, 'advanced_nlmdn_smoothing_control', "0.0"),
            (self.noise_std_entry, 'advanced_nlmdn_noise_std', "0.0"),
            (self.window_entry, 'advanced_nlmdn_window', "0.0"),
        ):
            entry.setText(text)
            parameters.params[key] = text
        self.fast_checkbox.setChecked(True)
        self.e_fast = True  # legacy attribute
        parameters.params['advanced_nlmdn_fast'] = True
        self.sigma_checkbox.setChecked(False)
        self.e_sig = False  # legacy attribute
        parameters.params['advanced_nlmdn_estimate_sigma'] = False
        self.e_dryrun = False  # legacy attribute
        parameters.params['advanced_nlmdn_dry_run'] = False

    def set_values_from_params(self):
        """Show the values currently stored in ``parameters.params``."""
        self.apply_to_reco_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_apply_after_reco']))
        self.input_dir_entry.setText(str(parameters.params['advanced_nlmdn_input_dir']))
        self.output_dir_entry.setText(str(parameters.params['advanced_nlmdn_output_dir']))
        self.save_bigtif_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_save_bigtiff']))
        self.similarity_radius_entry.setText(str(parameters.params['advanced_nlmdn_sim_search_radius']))
        self.patch_radius_entry.setText(str(parameters.params['advanced_nlmdn_patch_radius']))
        self.smoothing_entry.setText(str(parameters.params['advanced_nlmdn_smoothing_control']))
        self.noise_std_entry.setText(str(parameters.params['advanced_nlmdn_noise_std']))
        self.window_entry.setText(str(parameters.params['advanced_nlmdn_window']))
        self.fast_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_fast']))
        self.sigma_checkbox.setChecked(bool(parameters.params['advanced_nlmdn_estimate_sigma']))

    def set_apply_to_reco(self):
        """Store the checkbox state and toggle the manual input/output controls.

        While NLMDN is applied automatically after reconstruction, the manual
        input/output selection and the run buttons are disabled.
        """
        checked = self.apply_to_reco_checkbox.isChecked()
        LOG.debug("Apply NLMDN to reconstructed slices checkbox: " + str(checked))
        parameters.params['advanced_nlmdn_apply_after_reco'] = bool(checked)
        for widget in (
            self.input_dir_button,
            self.select_img_button,
            self.input_dir_entry,
            self.dry_button,
            self.apply_button,
            self.output_dir_button,
            self.output_dir_entry,
        ):
            widget.setDisabled(checked)

    def set_indir_button(self):
        """
        Saves directory specified by user in file-dialog for input tomographic data.
        The output directory defaults to ``<input>-nlmfilt``.
        """
        LOG.debug("Select input directory pressed")
        directory = QFileDialog.getExistingDirectory(self)
        if not directory:
            # The dialog returns an empty string when cancelled; keep the
            # previous selection instead of storing '' and '-nlmfilt'.
            LOG.debug("Input directory selection cancelled")
            return
        self.input_dir_entry.setText(directory)
        parameters.params['advanced_nlmdn_input_dir'] = directory
        self.output_dir_entry.setText(directory + "-nlmfilt")
        parameters.params['advanced_nlmdn_output_dir'] = directory + "-nlmfilt"
        parameters.params['advanced_nlmdn_input_is_file'] = False

    def set_indir_entry(self):
        """Store the typed input path and whether it names an existing file."""
        path = str(self.input_dir_entry.text())
        LOG.debug("Indir entry: " + path)
        parameters.params['advanced_nlmdn_input_dir'] = path
        # Keep input_is_file consistent with what the user typed, so that an
        # earlier "Select one image" choice does not linger.
        parameters.params['advanced_nlmdn_input_is_file'] = os.path.isfile(path)

    def select_image(self):
        """Select a single .tif image; the output becomes a numbered file pattern."""
        LOG.debug("Select one image button pressed")
        options = QFileDialog.Options()
        file_path, _ = QFileDialog.getOpenFileName(
            self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options
        )
        if file_path:
            img_name, img_ext = os.path.splitext(file_path)
            tmp = img_name + "-nlmfilt-%05i" + img_ext
            self.input_dir_entry.setText(file_path)
            self.output_dir_entry.setText(tmp)
            parameters.params['advanced_nlmdn_input_dir'] = file_path
            parameters.params['advanced_nlmdn_output_dir'] = tmp
            parameters.params['advanced_nlmdn_input_is_file'] = True

    def set_outdir_button(self):
        LOG.debug("Select output directory pressed")
        directory = QFileDialog.getExistingDirectory(self)
        if not directory:
            # Empty string means the dialog was cancelled.
            LOG.debug("Output directory selection cancelled")
            return
        self.output_dir_entry.setText(directory)
        parameters.params['advanced_nlmdn_output_dir'] = directory

    def set_save_bigtif(self):
        LOG.debug("Save bigtif checkbox: " + str(self.save_bigtif_checkbox.isChecked()))
        parameters.params['advanced_nlmdn_save_bigtiff'] = bool(self.save_bigtif_checkbox.isChecked())

    def set_outdir_entry(self):
        LOG.debug("Outdir entry: " + str(self.output_dir_entry.text()))
        parameters.params['advanced_nlmdn_output_dir'] = str(self.output_dir_entry.text())

    def set_rad_sim_entry(self):
        LOG.debug("Radius for similarity: " + str(self.similarity_radius_entry.text()))
        parameters.params['advanced_nlmdn_sim_search_radius'] = str(self.similarity_radius_entry.text())

    def set_rad_patch_entry(self):
        LOG.debug("Radius of patches: " + str(self.patch_radius_entry.text()))
        parameters.params['advanced_nlmdn_patch_radius'] = str(self.patch_radius_entry.text())

    def set_smoothing_entry(self):
        LOG.debug("Smoothing control: " + str(self.smoothing_entry.text()))
        parameters.params['advanced_nlmdn_smoothing_control'] = str(self.smoothing_entry.text())

    def set_noise_entry(self):
        LOG.debug("Noise std: " + str(self.noise_std_entry.text()))
        parameters.params['advanced_nlmdn_noise_std'] = str(self.noise_std_entry.text())

    def set_window_entry(self):
        LOG.debug("Window: " + str(self.window_entry.text()))
        parameters.params['advanced_nlmdn_window'] = str(self.window_entry.text())

    def set_fast_checkbox(self):
        LOG.debug("Fast: " + str(self.fast_checkbox.isChecked()))
        parameters.params['advanced_nlmdn_fast'] = bool(self.fast_checkbox.isChecked())

    def set_sigma_checkbox(self):
        LOG.debug("Estimate sigma: " + str(self.sigma_checkbox.isChecked()))
        parameters.params['advanced_nlmdn_estimate_sigma'] = bool(self.sigma_checkbox.isChecked())

    def help_button_pressed(self):
        LOG.debug("Help Button Pressed")
        # NOTE(review): this help text refers to flats/flats2 and looks copied
        # from another dialog; it does not describe the NLMDN settings.
        h = ""
        h += 'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n'
        h += "SerG, BMIT CLS, Dec. 2020."
        QMessageBox.information(self, "Help", h)

    def delete_button_pressed(self):
        """
        Deletes the directory that contains reconstructed data
        (``advanced_nlmdn_output_dir``) after the user confirms. Refuses when
        the output directory is the input directory or contains it.
        """
        LOG.debug("Delete Reco Button Pressed")
        msg = "Delete directory with reconstructed data?"
        dialog = QMessageBox.warning(self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No)
        if dialog != QMessageBox.Yes:
            LOG.debug("NO")
            return
        outdir = str(parameters.params['advanced_nlmdn_output_dir'])
        indir = str(parameters.params['advanced_nlmdn_input_dir'])
        # rmtree only works on directories; a file-name pattern or a plain
        # file must not reach it.
        if not os.path.isdir(outdir):
            LOG.debug("Directory does not exist")
            return
        LOG.debug("YES")
        # Compare resolved paths so trailing slashes, symlinks or relative
        # spellings cannot bypass the guard; an input nested inside the
        # output directory is protected as well.
        out_real = os.path.realpath(outdir)
        in_real = os.path.realpath(indir)
        if in_real == out_real or in_real.startswith(out_real.rstrip(os.sep) + os.sep):
            LOG.debug("Cannot delete: output directory is the input directory or contains it")
            return
        rmtree(outdir)
        LOG.debug("Directory with reconstructed data was removed")

    def dry_button_pressed(self):
        """Run the filter with the dry-run flag set, then clear the flag."""
        LOG.debug("Dry Run Button Pressed")
        parameters.params['advanced_nlmdn_dry_run'] = True
        try:
            self.apply_button_pressed()
        finally:
            # Never leave dry-run enabled for the next "Apply filter" click.
            parameters.params['advanced_nlmdn_dry_run'] = False

    def apply_button_pressed(self):
        """Run main_tk with the current params, asking before overwriting output."""
        LOG.debug("Apply Filter Button Pressed")
        p = parameters.params
        try:
            args = tk_args(p['advanced_nlmdn_apply_after_reco'],
                           p['advanced_nlmdn_input_dir'],
                           p.get('advanced_nlmdn_input_is_file', False),
                           p['advanced_nlmdn_output_dir'],
                           p['advanced_nlmdn_save_bigtiff'],
                           p['advanced_nlmdn_sim_search_radius'],
                           p['advanced_nlmdn_patch_radius'],
                           p['advanced_nlmdn_smoothing_control'],
                           p['advanced_nlmdn_noise_std'],
                           p['advanced_nlmdn_window'],
                           p['advanced_nlmdn_fast'],
                           p['advanced_nlmdn_estimate_sigma'],
                           p.get('advanced_nlmdn_dry_run', False))
        except ValueError as err:
            # Non-numeric text in a numeric entry. In PyQt5 >= 5.5 an
            # unhandled exception raised in a slot aborts the application,
            # so report the problem instead.
            LOG.debug("Invalid NLMDN parameter: " + str(err))
            QMessageBox.warning(self, "Invalid parameter",
                                "Cannot parse NLMDN parameters:\n" + str(err))
            return
        if os.path.exists(args.outdir) and not args.dryrun:
            title_text = "Warning: files can be overwritten"
            text1 = "Output directory exists. Files can be overwritten. Proceed?"
            dialog = QMessageBox.warning(self, title_text, text1, QMessageBox.Yes | QMessageBox.No)
            if dialog != QMessageBox.Yes:
                return
        main_tk(args)
        QMessageBox.information(self, "Finished", "Finished")
class tk_args:
    """
    Argument container for the NLM filter, as consumed by main_tk.

    Each value is converted once and then exposed both in the ``args`` dict
    and as an attribute with the same name.
    """

    def __init__(
        self,
        e_apply_after_reco,
        e_indir,
        e_input_is_file,
        e_outdir,
        e_bigtif,
        e_r,
        e_dx,
        e_h,
        e_sig,
        e_w,
        e_fast,
        e_autosig,
        e_dryrun,
    ):
        # The dict literal is evaluated top to bottom, so conversion errors
        # surface in the same order as field by field assignment would.
        converted = {
            # Paths
            # NOTE(review): apply_after_reco is stored as the *string*
            # "True"/"False", which is always truthy; confirm main_tk expects this.
            "apply_after_reco": str(e_apply_after_reco),
            "indir": str(e_indir),
            "input_is_file": e_input_is_file,
            "outdir": str(e_outdir),
            # Main algorithm parameters
            "search_r": int(e_r),
            "patch_r": int(e_dx),
            "h": float(e_h),
            "sig": float(e_sig),
            # Optional algorithm parameters
            "w": float(e_w),
            "fast": bool(e_fast),
            "autosig": bool(e_autosig),
            # Miscellaneous
            "bigtif": bool(e_bigtif),
            "dryrun": bool(e_dryrun),
        }
        self.args = converted
        for key, value in converted.items():
            setattr(self, key, value)
tofu-0.12.0/tofu/ez/GUI/Advanced/optimization.py 0000664 0000000 0000000 00000007556 14237137211 0021411 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class OptimizationGroup(QGroupBox):
    """
    Group box with the optimization settings: verbose console output and the
    GPU options (slice memory coefficient, number of GPUs, slices per device).
    Every edit is copied into ``parameters.params``.
    """

    # Tooltip shared by the slice-memory label and entry.
    _SLICE_MEMORY_TIP = (
        "Fraction of VRAM which will be used to store images \n"
        "Reserve ~2 GB of VRAM for computation \n"
        "Decrease the coefficient if you have very large data and start getting errors"
    )

    def __init__(self):
        super().__init__()

        self.setTitle("Optimization Settings")
        self.setStyleSheet("QGroupBox {color: orange;}")

        self.verbose_switch = QCheckBox("Enable verbose console output")
        self.verbose_switch.stateChanged.connect(self.set_verbose_switch)

        self.slice_memory_label = QLabel("Slice memory coefficient")
        self.slice_memory_entry = QLineEdit()
        for widget in (self.slice_memory_entry, self.slice_memory_label):
            widget.setToolTip(self._SLICE_MEMORY_TIP)
        self.slice_memory_entry.editingFinished.connect(self.set_slice)

        self.num_GPU_label = QLabel("Number of GPUs")
        self.num_GPU_entry = QLineEdit()
        self.num_GPU_entry.editingFinished.connect(self.set_num_gpu)

        self.slices_per_device_label = QLabel("Slices per device")
        self.slices_per_device_entry = QLineEdit()
        self.slices_per_device_entry.editingFinished.connect(self.set_slices_per_device)

        self.set_layout()

    def set_layout(self):
        gpu_group = QGroupBox("GPU optimization")
        gpu_group.setCheckable(True)
        gpu_group.setChecked(False)

        gpu_grid = QGridLayout()
        rows = (
            (self.slice_memory_label, self.slice_memory_entry),
            (self.num_GPU_label, self.num_GPU_entry),
            (self.slices_per_device_label, self.slices_per_device_entry),
        )
        for row, (label, entry) in enumerate(rows):
            gpu_grid.addWidget(label, row, 0)
            gpu_grid.addWidget(entry, row, 1)
        gpu_group.setLayout(gpu_grid)

        outer = QGridLayout()
        outer.addWidget(self.verbose_switch, 0, 0)
        outer.addWidget(gpu_group, 1, 0)
        self.setLayout(outer)

    def init_values(self):
        """Reset the widgets and the matching params to their defaults."""
        self.verbose_switch.setChecked(False)
        parameters.params['advanced_optimize_verbose_console'] = False
        parameters.params['advanced_optimize_slice_mem_coeff'] = 0.7
        self.slice_memory_entry.setText(
            str(parameters.params['advanced_optimize_slice_mem_coeff']))
        for entry, key in (
            (self.num_GPU_entry, 'advanced_optimize_num_gpus'),
            (self.slices_per_device_entry, 'advanced_optimize_slices_per_device'),
        ):
            entry.setText("")
            parameters.params[key] = ""

    def set_values_from_params(self):
        """Show the values currently stored in ``parameters.params``."""
        self.verbose_switch.setChecked(bool(parameters.params['advanced_optimize_verbose_console']))
        for entry, key in (
            (self.slice_memory_entry, 'advanced_optimize_slice_mem_coeff'),
            (self.num_GPU_entry, 'advanced_optimize_num_gpus'),
            (self.slices_per_device_entry, 'advanced_optimize_slices_per_device'),
        ):
            entry.setText(str(parameters.params[key]))

    @staticmethod
    def _store_entry(entry, key):
        """Log the text of *entry* and store it as a string under *key*."""
        text = entry.text()
        LOG.debug(text)
        parameters.params[key] = str(text)

    def set_verbose_switch(self):
        checked = self.verbose_switch.isChecked()
        LOG.debug("Verbose: " + str(checked))
        parameters.params['advanced_optimize_verbose_console'] = bool(checked)

    def set_slice(self):
        self._store_entry(self.slice_memory_entry, 'advanced_optimize_slice_mem_coeff')

    def set_num_gpu(self):
        self._store_entry(self.num_GPU_entry, 'advanced_optimize_num_gpus')

    def set_slices_per_device(self):
        self._store_entry(self.slices_per_device_entry, 'advanced_optimize_slices_per_device')
from PyQt5.QtWidgets import QGridLayout, QLabel, QRadioButton, QGroupBox, QLineEdit
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class CentreOfRotationGroup(QGroupBox):
"""
Centre of Rotation settings
"""
def __init__(self):
    super().__init__()
    self.setTitle("Centre of Rotation")
    self.setStyleSheet("QGroupBox {color: green;}")

    def radio(text):
        # Every method radio button reports through the same slot
        button = QRadioButton(text)
        button.clicked.connect(self.set_rButton)
        return button

    def entry(slot):
        # White line edit that commits its text via `slot` when editing finishes
        edit = QLineEdit()
        edit.editingFinished.connect(slot)
        edit.setStyleSheet("background-color:white")
        return edit

    self.auto_correlate_rButton = radio("Auto: Correlate first/last projections")
    self.auto_minimize_rButton = radio("Auto: Minimize STD of a slice")
    self.auto_minimize_rButton.setToolTip(
        "Reconstructed patches are saved \nin your-temporary-data-folder\\axis-search"
    )
    self.define_axis_rButton = radio("Define rotation axis manually")

    self.search_rotation_label = QLabel("Search rotation axis in [start, stop, step] interval")
    self.search_rotation_entry = entry(self.set_search_rotation)
    self.search_in_slice_label = QLabel("Search in slice from row number")
    self.search_in_slice_entry = entry(self.set_search_slice)
    self.size_of_recon_label = QLabel("Size of reconstructed patch [pixel]")
    self.size_of_recon_entry = entry(self.set_size_of_reco)
    self.axis_col_label = QLabel("Axis is in column No [pixel]")
    self.axis_col_entry = entry(self.set_axis_col)
    self.inc_axis_label = QLabel("Increment axis every reconstruction")
    self.inc_axis_entry = entry(self.set_axis_inc)

    self.image_midpoint_rButton = radio("Use image midpoint (for half-acquisition)")

    # TODO Used for proper spacing - should be a better way
    self.blank_label = QLabel(" ")
    self.blank_label2 = QLabel(" ")

    self.set_layout()
def set_layout(self):
    """Arrange the CoR widgets: controls in column 0, entries spanning columns 1-2."""
    layout = QGridLayout()

    # Row 0: two blank labels reserve space in columns 1 and 2
    for column, widget in enumerate(
        (self.auto_correlate_rButton, self.blank_label, self.blank_label2)
    ):
        layout.addWidget(widget, 0, column)

    def add_row(row, label, entry):
        layout.addWidget(label, row, 0)
        layout.addWidget(entry, row, 1, 1, 2)

    layout.addWidget(self.auto_minimize_rButton, 1, 0)
    add_row(2, self.search_rotation_label, self.search_rotation_entry)
    add_row(3, self.search_in_slice_label, self.search_in_slice_entry)
    add_row(4, self.size_of_recon_label, self.size_of_recon_entry)
    layout.addWidget(self.define_axis_rButton, 5, 0)
    add_row(6, self.axis_col_label, self.axis_col_entry)
    add_row(7, self.inc_axis_label, self.inc_axis_entry)
    layout.addWidget(self.image_midpoint_rButton, 8, 0)
    self.setLayout(layout)
def init_values(self):
    """Reset the CoR widgets to their defaults and mirror them into parameters.params.

    QLineEdit.setText() does not emit editingFinished, so the entry slots are
    invoked explicitly after each default is displayed. Otherwise
    parameters.params would keep stale values while the GUI shows the defaults.
    """
    self.auto_correlate_rButton.setChecked(True)
    self.auto_minimize_rButton.setChecked(False)
    self.define_axis_rButton.setChecked(False)
    self.image_midpoint_rButton.setChecked(False)
    self.set_rButton()
    defaults = (
        (self.search_rotation_entry, "1010,1030,0.5", self.set_search_rotation),
        (self.search_in_slice_entry, "100", self.set_search_slice),
        (self.size_of_recon_entry, "256", self.set_size_of_reco),
        (self.axis_col_entry, "0.0", self.set_axis_col),
        (self.inc_axis_entry, "0.0", self.set_axis_inc),
    )
    for entry, text, commit in defaults:
        entry.setText(text)
        commit()
def set_values_from_params(self):
    """Refresh the radio buttons and entries from parameters.params (e.g. after a YAML import)."""
    self.set_rButton_from_params()
    fields = (
        (self.search_rotation_entry, 'main_cor_axis_search_interval'),
        (self.search_in_slice_entry, 'main_cor_search_row_start'),
        (self.size_of_recon_entry, 'main_cor_recon_patch_size'),
        (self.axis_col_entry, 'main_cor_axis_column'),
        (self.inc_axis_entry, 'main_cor_axis_increment_step'),
    )
    for widget, key in fields:
        widget.setText(str(parameters.params[key]))
def set_rButton(self):
    """Record the checked centre-of-rotation method (codes 1-4) in parameters.params.

    The first checked button wins; nothing is written if none is checked.
    """
    choices = (
        (self.auto_correlate_rButton, "Auto Correlate", 1),
        (self.auto_minimize_rButton, "Auto Minimize", 2),
        (self.define_axis_rButton, "Define axis", 3),
        (self.image_midpoint_rButton, "Use image midpoint", 4),
    )
    for button, description, method in choices:
        if button.isChecked():
            LOG.debug(description)
            parameters.params['main_cor_axis_search_method'] = method
            break
def set_rButton_from_params(self):
    """Check the radio button matching parameters.params['main_cor_axis_search_method'].

    Codes 1-4 map to correlate / minimize / manual / image-midpoint. Any other
    value leaves all buttons unchanged.
    """
    method = parameters.params['main_cor_axis_search_method']
    buttons = (
        (1, self.auto_correlate_rButton),
        (2, self.auto_minimize_rButton),
        (3, self.define_axis_rButton),
        (4, self.image_midpoint_rButton),
    )
    selected = next((code for code, _ in buttons if method == code), None)
    if selected is None:
        return
    for code, button in buttons:
        button.setChecked(code == selected)
def set_search_rotation(self):
    """Store the axis-search interval text ("start,stop,step") in parameters.params."""
    interval = self.search_rotation_entry.text()
    LOG.debug(interval)
    parameters.params['main_cor_axis_search_interval'] = str(interval)
def set_search_slice(self):
    """Store the search-row text (as a string) in parameters.params."""
    row = self.search_in_slice_entry.text()
    LOG.debug(row)
    parameters.params['main_cor_search_row_start'] = str(row)
def set_size_of_reco(self):
    """Store the reconstructed-patch size text (as a string) in parameters.params."""
    patch_size = self.size_of_recon_entry.text()
    LOG.debug(patch_size)
    parameters.params['main_cor_recon_patch_size'] = str(patch_size)
def set_axis_col(self):
    """Store the manual axis-column text (as a string) in parameters.params."""
    column = self.axis_col_entry.text()
    LOG.debug(column)
    parameters.params['main_cor_axis_column'] = str(column)
# Slot for the "Increment axis every reconstruction" entry: logs the current
# entry text and, in the statement that follows, stores it as a string under
# parameters.params['main_cor_axis_increment_step'].
def set_axis_inc(self):
LOG.debug(self.inc_axis_entry.text())
parameters.params['main_cor_axis_increment_step'] = str(self.inc_axis_entry.text()) tofu-0.12.0/tofu/ez/GUI/Main/config.py 0000664 0000000 0000000 00000161623 14237137211 0017303 0 ustar 00root root 0000000 0000000 import os
import logging
import numpy as np
from shutil import rmtree
from PyQt5.QtWidgets import (
QMessageBox,
QFileDialog,
QCheckBox,
QPushButton,
QGridLayout,
QLabel,
QGroupBox,
QLineEdit,
)
from PyQt5.QtCore import QCoreApplication
from PyQt5.QtCore import pyqtSignal
from PyQt5.QtCore import Qt
from tofu.ez.main import execute_reconstruction, clean_tmp_dirs
from tofu.ez.yaml_in_out import Yaml_IO
from tofu.ez.GUI.message_dialog import warning_message
import tofu.ez.params as parameters
#TODO Get rid of the old args structure and store all parameters
# like tofu does
LOG = logging.getLogger(__name__)
class ConfigGroup(QGroupBox):
"""
Setup and configuration settings
"""
# Used to send signal to ezufo_launcher when settings are imported https://stackoverflow.com/questions/2970312/pyqt4-qtcore-pyqtsignal-object-has-no-attribute-connect
signal_update_vals_from_params = pyqtSignal(dict)
# Used to send signal when reconstruction is done
signal_reco_done = pyqtSignal(dict)
def __init__(self):
    super().__init__()
    self.setTitle("Configuration")
    self.setStyleSheet("QGroupBox {color: purple;}")
    self.yaml_io = Yaml_IO()

    grey_12 = "background-color:lightgrey; font: 12pt;"
    grey_13_bold = "background-color:lightgrey; font: 13pt; font-weight: bold"

    def line_edit(slot, text=None):
        # QLineEdit that commits its contents through `slot` on editingFinished
        edit = QLineEdit()
        if text is not None:
            edit.setText(text)
        edit.editingFinished.connect(slot)
        return edit

    def button(text, style=None, tooltip=None):
        btn = QPushButton(text)
        if style is not None:
            btn.setStyleSheet(style)
        if tooltip is not None:
            btn.setToolTip(tooltip)
        return btn

    def checkbox(text, tooltip=None):
        box = QCheckBox(text)
        if tooltip is not None:
            box.setToolTip(tooltip)
        return box

    # Select input directory
    self.input_dir_select = button("Select input directory (or paste abs. path)", grey_12)
    self.input_dir_entry = line_edit(self.set_input_dir)
    self.input_dir_select.pressed.connect(self.select_input_dir)

    # Save .params checkbox
    self.save_params_checkbox = checkbox("Save args in .params file")
    self.save_params_checkbox.stateChanged.connect(self.set_save_args)

    # Select output directory
    self.output_dir_select = button("Select output directory (or paste abs. path)", grey_12)
    self.output_dir_entry = line_edit(self.set_output_dir)
    self.output_dir_select.pressed.connect(self.select_output_dir)

    # Save in separate files or in one huge tiff file
    self.bigtiff_checkbox = checkbox(
        "Save slices in multipage tiffs",
        "Will save images in bigtiff containers. \n"
        "Note that some temporary data is always saved in bigtiffs.\n"
        "Use bio-formats importer plugin for imagej or fiji to open the bigtiffs.",
    )
    self.bigtiff_checkbox.stateChanged.connect(self.set_big_tiff)

    # Generic ufo-launch preprocessing pipeline
    self.preproc_checkbox = checkbox(
        "Preprocess with a generic ufo-launch pipeline, f.i.",
        "Selected ufo filters will be applied to each "
        "image before reconstruction begins. \n"
        'To print the list of filters use "ufo-query -l" command. \n'
        'Parameters of each filter can be seen with "ufo-query -p filtername".',
    )
    self.preproc_checkbox.stateChanged.connect(self.set_preproc)
    self.preproc_entry = line_edit(self.set_preproc_entry)

    # Names of directories with flats/darks/projections frames
    self.e_DIRTYP = ["darks", "flats", "tomo", "flats2"]
    self.dir_name_label = QLabel("Name of flats/darks/tomo subdirectories in each CT data set")
    self.darks_entry = line_edit(self.set_darks)
    self.flats_entry = line_edit(self.set_flats)
    self.tomo_entry = line_edit(self.set_tomo)
    self.flats2_entry = line_edit(self.set_flats2)

    # Select flats/darks/flats2 for use in multiple reconstructions
    self.use_common_flats_darks_checkbox = checkbox(
        "Use common flats/darks across multiple experiments"
    )
    self.use_common_flats_darks_checkbox.stateChanged.connect(self.set_flats_darks_checkbox)
    self.select_darks_button = button(
        "Select path to darks (or paste abs. path)",
        tooltip="Background detector noise",
    )
    self.select_darks_button.clicked.connect(self.select_darks_button_pressed)
    self.select_flats_button = button(
        "Select path to flats (or paste abs. path)",
        tooltip="Images without sample in the beam",
    )
    self.select_flats_button.clicked.connect(self.select_flats_button_pressed)
    self.select_flats2_button = button(
        "Select path to flats2 (or paste abs. path)",
        tooltip="If selected, it will be assumed that flats were \n"
        "acquired before projections while flats2 after \n"
        "and interpolation will be used to compute intensity of flat image \n"
        "for each projection between flats and flats2",
    )
    self.select_flats2_button.clicked.connect(self.select_flats2_button_pressed)
    self.darks_absolute_entry = line_edit(self.set_common_darks, "Absolute path to darks")
    self.flats_absolute_entry = line_edit(self.set_common_flats, "Absolute path to flats")
    self.use_flats2_checkbox = checkbox("Use common flats2")
    self.use_flats2_checkbox.clicked.connect(self.set_use_flats2)
    self.flats2_absolute_entry = line_edit(self.set_common_flats2, "Absolute path to flats2")

    # Select temporary directory
    self.temp_dir_select = button(
        "Select temporary directory (or paste abs. path)",
        grey_12,
        "Temporary data will be saved there.\n"
        "note that the size of temporary data can exceed 300 GB in some cases.",
    )
    self.temp_dir_select.pressed.connect(self.select_temp_dir)
    self.temp_dir_entry = line_edit(self.set_temp_dir)

    # Keep temp data selection
    self.keep_tmp_data_checkbox = checkbox(
        "Keep all temp data till the end of reconstruction",
        "Useful option to inspect how images change at each step",
    )
    self.keep_tmp_data_checkbox.stateChanged.connect(self.set_keep_tmp_data)

    # Import / export settings from / to file
    self.open_settings_file = button("Import parameters from file", grey_12)
    self.open_settings_file.pressed.connect(self.import_settings_button_pressed)
    self.save_settings_file = button("Export parameters to file", grey_12)
    self.save_settings_file.pressed.connect(self.export_settings_button_pressed)

    # Bottom row of action buttons
    self.quit_button = button(
        "Quit", "background-color:lightgrey; font: 13pt; font-weight: bold;"
    )
    self.quit_button.clicked.connect(self.quit_button_pressed)
    self.help_button = button("Help", grey_13_bold)
    self.help_button.clicked.connect(self.help_button_pressed)
    self.delete_reco_dir_button = button("Delete reco dir", grey_13_bold)
    self.delete_reco_dir_button.clicked.connect(self.delete_button_pressed)
    self.dry_run_button = button("Dry run", grey_13_bold)
    self.dry_run_button.clicked.connect(self.dryrun_button_pressed)
    self.reco_button = button(
        "Reconstruct",
        "background-color:lightgrey;color:royalblue; font: 14pt; font-weight: bold;",
    )
    self.reco_button.clicked.connect(self.reco_button_pressed)

    # Open image after reconstruct
    self.open_image_after_reco_checkbox = checkbox(
        "Load images and open viewer after reconstruction"
    )
    self.open_image_after_reco_checkbox.clicked.connect(self.set_open_image_after_reco)

    self.set_layout()
def set_layout(self):
    """
    Sets the layout of buttons, labels, etc. for config group
    """
    layout = QGridLayout()

    # Option checkboxes stacked in a box that spans rows 0-3 of column 4
    checkbox_layout = QGridLayout()
    for row, box in enumerate((
        self.save_params_checkbox,
        self.bigtiff_checkbox,
        self.open_image_after_reco_checkbox,
        self.keep_tmp_data_checkbox,
    )):
        checkbox_layout.addWidget(box, row, 0)
    checkbox_groupbox = QGroupBox()
    checkbox_groupbox.setLayout(checkbox_layout)
    layout.addWidget(checkbox_groupbox, 0, 4, 4, 1)

    # Rows 0-3: a control in column 0 and its entry spanning columns 1-3
    for row, (control, entry) in enumerate((
        (self.input_dir_select, self.input_dir_entry),
        (self.output_dir_select, self.output_dir_entry),
        (self.temp_dir_select, self.temp_dir_entry),
        (self.preproc_checkbox, self.preproc_entry),
    )):
        layout.addWidget(control, row, 0)
        layout.addWidget(entry, row, 1, 1, 3)

    # Sub-directory names and common flats/darks paths
    fdt_layout = QGridLayout()
    fdt_layout.addWidget(self.dir_name_label, 0, 0)
    for column, entry in enumerate(
        (self.darks_entry, self.flats_entry, self.tomo_entry, self.flats2_entry), start=1
    ):
        fdt_layout.addWidget(entry, 0, column)
    fdt_layout.addWidget(self.use_common_flats_darks_checkbox, 1, 0)
    fdt_layout.addWidget(self.select_darks_button, 1, 1)
    fdt_layout.addWidget(self.select_flats_button, 1, 2)
    fdt_layout.addWidget(self.select_flats2_button, 1, 4)
    fdt_layout.addWidget(self.darks_absolute_entry, 2, 1)
    fdt_layout.addWidget(self.flats_absolute_entry, 2, 2)
    fdt_layout.addWidget(self.use_flats2_checkbox, 2, 3, Qt.AlignRight)
    fdt_layout.addWidget(self.flats2_absolute_entry, 2, 4)
    fdt_groupbox = QGroupBox()
    fdt_groupbox.setLayout(fdt_layout)
    layout.addWidget(fdt_groupbox, 4, 0, 1, 5)

    layout.addWidget(self.open_settings_file, 5, 0, 1, 3)
    layout.addWidget(self.save_settings_file, 5, 3, 1, 2)

    # Row 6: action buttons, one per column
    for column, btn in enumerate((
        self.quit_button,
        self.help_button,
        self.delete_reco_dir_button,
        self.dry_run_button,
        self.reco_button,
    )):
        layout.addWidget(btn, 6, column)
    self.setLayout(layout)
def init_values(self):
    """
    Sets the initial default values of config group

    Programmatic updates do not fire every slot:
    - QLineEdit.setText does not emit editingFinished.
    - Checkboxes wired to `clicked` do not react to setChecked.
    - stateChanged is not emitted when the state is unchanged.
    The affected parameters are therefore written explicitly, so that
    parameters.params matches what the widgets display.
    """
    # Default input/output directories are derived from the user's home directory
    indir = os.path.expanduser('~')
    if os.path.isdir(indir):
        self.input_dir_entry.setText(indir)
        self.output_dir_entry.setText(os.path.abspath(os.path.join(indir, "rec")))
        self.set_input_dir()
        self.set_output_dir()
    self.save_params_checkbox.setChecked(True)
    parameters.params['main_config_save_params'] = True
    # The checkbox must agree with the parameter being reset
    self.bigtiff_checkbox.setChecked(False)
    parameters.params['main_config_save_multipage_tiff'] = False
    self.preproc_checkbox.setChecked(False)
    self.set_preproc()
    self.preproc_entry.setText("remove-outliers size=3 threshold=500 sign=1")
    self.set_preproc_entry()
    self.darks_entry.setText("darks")
    self.flats_entry.setText("flats")
    self.tomo_entry.setText("tomo")
    self.flats2_entry.setText("flats2")
    # Also resets self.e_DIRTYP to the displayed names
    self.set_fdt_names()
    self.use_common_flats_darks_checkbox.setChecked(False)
    parameters.params['main_config_common_flats_darks'] = False
    self.darks_absolute_entry.setText("Absolute path to darks")
    self.flats_absolute_entry.setText("Absolute path to flats")
    # Previously use_common_flats_darks_checkbox was reset a second time here
    # and the flats2 checkbox was never reset
    self.use_flats2_checkbox.setChecked(False)
    parameters.params['main_config_flats2_checkbox'] = False
    self.flats2_absolute_entry.setText("Absolute path to flats2")
    self.temp_dir_entry.setText(os.path.join(os.path.expanduser('~'), "tmp-ezufo"))
    self.set_temp_dir()
    self.keep_tmp_data_checkbox.setChecked(False)
    parameters.params['main_config_keep_temp'] = False
    self.dry_run_button.setChecked(False)
    parameters.params['main_config_dry_run'] = False
    self.open_image_after_reco_checkbox.setChecked(False)
    parameters.params['main_config_open_viewer'] = False
def set_values_from_params(self):
    """
    Updates displayed values for config group
    Called when .yaml file of params is loaded
    """
    # (widget setter, parameter key) pairs, applied in this order
    bindings = (
        (self.input_dir_entry.setText, 'main_config_input_dir'),
        (self.save_params_checkbox.setChecked, 'main_config_save_params'),
        (self.output_dir_entry.setText, 'main_config_output_dir'),
        (self.bigtiff_checkbox.setChecked, 'main_config_save_multipage_tiff'),
        (self.preproc_checkbox.setChecked, 'main_config_preprocess'),
        (self.preproc_entry.setText, 'main_config_preprocess_command'),
        (self.darks_entry.setText, 'main_config_darks_dir_name'),
        (self.flats_entry.setText, 'main_config_flats_dir_name'),
        (self.tomo_entry.setText, 'main_config_tomo_dir_name'),
        (self.flats2_entry.setText, 'main_config_flats2_dir_name'),
        (self.temp_dir_entry.setText, 'main_config_temp_dir'),
        (self.keep_tmp_data_checkbox.setChecked, 'main_config_keep_temp'),
        (self.dry_run_button.setChecked, 'main_config_dry_run'),
        (self.open_image_after_reco_checkbox.setChecked, 'main_config_open_viewer'),
        (self.use_common_flats_darks_checkbox.setChecked, 'main_config_common_flats_darks'),
        (self.darks_absolute_entry.setText, 'main_config_darks_path'),
        (self.flats_absolute_entry.setText, 'main_config_flats_path'),
        (self.use_flats2_checkbox.setChecked, 'main_config_flats2_checkbox'),
        (self.flats2_absolute_entry.setText, 'main_config_flats2_path'),
    )
    for apply, key in bindings:
        apply(parameters.params[key])
def select_input_dir(self):
    """
    Saves directory specified by user in file-dialog for input tomographic data

    The static QFileDialog helper is used directly. Previously a parented
    QFileDialog was created and never deleted on each click. Cancelling the
    dialog (empty result) keeps the current input directory.
    """
    directory = QFileDialog.getExistingDirectory(
        self, directory=self.input_dir_entry.text()
    )
    if not directory:
        return
    self.input_dir_entry.setText(directory)
    parameters.params['main_config_input_dir'] = directory
def set_input_dir(self):
    """Store the input-directory entry text in parameters.params."""
    path = str(self.input_dir_entry.text())
    LOG.debug(path)
    parameters.params['main_config_input_dir'] = path
def select_output_dir(self):
    """
    Let the user pick the output directory and store it in parameters.params.

    Cancelling the dialog (empty result) keeps the current output directory.
    No throwaway QFileDialog child widget is created.
    """
    directory = QFileDialog.getExistingDirectory(
        self, directory=self.output_dir_entry.text()
    )
    if not directory:
        return
    self.output_dir_entry.setText(directory)
    parameters.params['main_config_output_dir'] = directory
def set_output_dir(self):
    """Store the output-directory entry text in parameters.params."""
    path = str(self.output_dir_entry.text())
    LOG.debug(path)
    parameters.params['main_config_output_dir'] = path
def set_big_tiff(self):
    """Store the multipage-tiff checkbox state as a bool in parameters.params."""
    checked = self.bigtiff_checkbox.isChecked()
    LOG.debug("Bigtiff: " + str(checked))
    parameters.params['main_config_save_multipage_tiff'] = bool(checked)
def set_preproc(self):
    """Store the preprocess checkbox state as a bool in parameters.params."""
    checked = self.preproc_checkbox.isChecked()
    LOG.debug("Preproc: " + str(checked))
    parameters.params['main_config_preprocess'] = bool(checked)
def set_preproc_entry(self):
    """Store the ufo-launch preprocessing command text in parameters.params."""
    command = self.preproc_entry.text()
    LOG.debug(command)
    parameters.params['main_config_preprocess_command'] = str(command)
def set_open_image_after_reco(self):
    """Store whether the image viewer should open after reconstruction."""
    checked = self.open_image_after_reco_checkbox.isChecked()
    LOG.debug("Switch to Image Viewer After Reco: " + str(checked))
    parameters.params['main_config_open_viewer'] = bool(checked)
def set_darks(self):
    """Store the darks sub-directory name in e_DIRTYP[0] and parameters.params."""
    text = self.darks_entry.text()
    LOG.debug(text)
    name = str(text)
    self.e_DIRTYP[0] = name
    parameters.params['main_config_darks_dir_name'] = name
def set_flats(self):
    """Store the flats sub-directory name in e_DIRTYP[1] and parameters.params."""
    text = self.flats_entry.text()
    LOG.debug(text)
    name = str(text)
    self.e_DIRTYP[1] = name
    parameters.params['main_config_flats_dir_name'] = name
def set_tomo(self):
    """Store the tomo (projections) sub-directory name in e_DIRTYP[2] and parameters.params."""
    text = self.tomo_entry.text()
    LOG.debug(text)
    name = str(text)
    self.e_DIRTYP[2] = name
    parameters.params['main_config_tomo_dir_name'] = name
def set_flats2(self):
    """Store the flats2 sub-directory name in e_DIRTYP[3] and parameters.params."""
    text = self.flats2_entry.text()
    LOG.debug(text)
    name = str(text)
    self.e_DIRTYP[3] = name
    parameters.params['main_config_flats2_dir_name'] = name
def set_fdt_names(self):
    """Commit all sub-directory name entries, in the order darks, flats, flats2, tomo."""
    for commit in (self.set_darks, self.set_flats, self.set_flats2, self.set_tomo):
        commit()
def set_flats_darks_checkbox(self):
    """Store whether common flats/darks are shared across experiments."""
    checked = self.use_common_flats_darks_checkbox.isChecked()
    LOG.debug("Use same flats/darks across multiple experiments: " + str(checked))
    parameters.params['main_config_common_flats_darks'] = bool(checked)
def select_darks_button_pressed(self):
    """
    Let the user pick the common darks directory, starting in the input directory.

    Cancelling the dialog (empty result) keeps the current path.
    No throwaway QFileDialog child widget is created.
    """
    LOG.debug("Select path to darks pressed")
    directory = QFileDialog.getExistingDirectory(
        self, directory=parameters.params['main_config_input_dir']
    )
    if not directory:
        return
    self.darks_absolute_entry.setText(directory)
    parameters.params['main_config_darks_path'] = directory
def select_flats_button_pressed(self):
    """
    Let the user pick the common flats directory, starting in the input directory.

    Cancelling the dialog (empty result) keeps the current path.
    No throwaway QFileDialog child widget is created.
    """
    LOG.debug("Select path to flats pressed")
    directory = QFileDialog.getExistingDirectory(
        self, directory=parameters.params['main_config_input_dir']
    )
    if not directory:
        return
    self.flats_absolute_entry.setText(directory)
    parameters.params['main_config_flats_path'] = directory
def select_flats2_button_pressed(self):
    """
    Let the user pick the common flats2 directory, starting in the input directory.

    Cancelling the dialog (empty result) keeps the current path.
    No throwaway QFileDialog child widget is created.
    """
    LOG.debug("Select path to flats2 pressed")
    directory = QFileDialog.getExistingDirectory(
        self, directory=parameters.params['main_config_input_dir']
    )
    if not directory:
        return
    self.flats2_absolute_entry.setText(directory)
    parameters.params['main_config_flats2_path'] = directory
def set_common_darks(self):
    """Store the common-darks path text in parameters.params."""
    path = str(self.darks_absolute_entry.text())
    LOG.debug("Common darks path: " + path)
    parameters.params['main_config_darks_path'] = path
def set_common_flats(self):
    """Store the common-flats path text in parameters.params."""
    path = str(self.flats_absolute_entry.text())
    LOG.debug("Common flats path: " + path)
    parameters.params['main_config_flats_path'] = path
def set_use_flats2(self):
    """Store the "Use common flats2" checkbox state as a bool in parameters.params."""
    checked = self.use_flats2_checkbox.isChecked()
    LOG.debug("Use common flats2 checkbox: " + str(checked))
    parameters.params['main_config_flats2_checkbox'] = bool(checked)
def set_common_flats2(self):
    """Store the common-flats2 path text in parameters.params."""
    path = str(self.flats2_absolute_entry.text())
    LOG.debug("Common flats2 path: " + path)
    parameters.params['main_config_flats2_path'] = path
def select_temp_dir(self):
    """
    Let the user pick the temporary directory and store it in parameters.params.

    setText() does not emit editingFinished, so set_temp_dir() is called
    explicitly. Previously the chosen directory was displayed but never used.
    Cancelling the dialog (empty result) keeps the current path.
    """
    tmp_dir = QFileDialog.getExistingDirectory(
        self, directory=self.temp_dir_entry.text()
    )
    if not tmp_dir:
        return
    self.temp_dir_entry.setText(tmp_dir)
    self.set_temp_dir()
def set_temp_dir(self):
    """Store the temporary-directory entry text in parameters.params."""
    path = str(self.temp_dir_entry.text())
    LOG.debug(path)
    parameters.params['main_config_temp_dir'] = path
def set_keep_tmp_data(self):
    """Store the keep-temporary-data checkbox state as a bool in parameters.params."""
    checked = self.keep_tmp_data_checkbox.isChecked()
    LOG.debug("Keep tmp: " + str(checked))
    parameters.params['main_config_keep_temp'] = bool(checked)
def quit_button_pressed(self):
    """
    Displays confirmation dialog and cleans temporary directories

    On confirmation:
    - clean_tmp_dirs() is called for the temporary directory.
    - The 'axis-search' sub-directory of the temporary directory is removed.
      Previously the path was computed but never deleted.
    - The application then quits.
    """
    LOG.debug("QUIT")
    reply = QMessageBox.question(
        self,
        "Quit",
        "Are you sure you want to quit?",
        QMessageBox.Yes | QMessageBox.No,
        QMessageBox.No,
    )
    if reply != QMessageBox.Yes:
        return
    tmp_root = parameters.params['main_config_temp_dir']
    # remove all directories with projections
    clean_tmp_dirs(tmp_root, self.get_fdt_names())
    # remove axis-search dir too; best effort so that quitting never fails here
    axis_search_dir = os.path.join(tmp_root, 'axis-search')
    if os.path.isdir(axis_search_dir):
        rmtree(axis_search_dir, ignore_errors=True)
    QCoreApplication.instance().quit()
def help_button_pressed(self):
    """
    Displays pop-up help information
    """
    LOG.debug("HELP")
    h = (
        "This utility provides an interface to the ufo-kit software package.\n"
        "Use it for batch processing and optimization of reconstruction parameters.\n"
        "It creates a list of paths to all CT directories in the _input_ directory.\n"
        "A CT directory is defined as directory with at least \n"
        "_flats_, _darks_, _tomo_, and, optionally, _flats2_ subdirectories, \n"
        "which are not empty and contain only *.tif files. Names of CT\n"
        "directories are compared with the directory tree in the _output_ directory.\n"
        "(Note: relative directory tree in _input_ is preserved when writing results to the"
        " _output_.)\n"
        # The newline after "directory." was missing, gluing this line to the next
        "Those CT sets will be reconstructed, whose names are not yet in the _output_"
        " directory.\n"
        "Program will create an array of ufo/tofu commands according to defined parameters \n"
        "and then execute them sequentially. These commands can be also printed on the"
        " screen.\n"
        "Note2: if you bin in preprocess the center of rotation will change a lot; \n"
        'Note4: set to "flats" if "flats2" exist but you need to ignore them; \n'
        "Created by Sergei Gasilov, BMIT CLS, Dec. 2018.\n Extended by Iain Emslie, Summer"
        " 2021."
    )
    QMessageBox.information(self, "Help", h)
def delete_button_pressed(self):
    """
    Deletes the directory that contains reconstructed data

    Asks for confirmation first. Deletion is refused when the output
    directory is the input directory or contains it, because rmtree would
    otherwise destroy the input data too. Paths are compared after
    os.path.realpath normalisation. Errors while deleting are reported
    instead of being silently swallowed.
    """
    LOG.debug("DELETE")
    msg = "Delete directory with reconstructed data?"
    dialog = QMessageBox.warning(
        self, "Warning: data can be lost", msg, QMessageBox.Yes | QMessageBox.No
    )
    if dialog != QMessageBox.Yes:
        LOG.debug("NO")
        return
    out_dir = str(parameters.params['main_config_output_dir'])
    if not os.path.exists(out_dir):
        LOG.debug("Directory does not exist")
        return
    LOG.debug("YES")
    out_real = os.path.realpath(out_dir)
    in_real = os.path.realpath(str(parameters.params['main_config_input_dir']))
    try:
        contains_input = os.path.commonpath([out_real, in_real]) == out_real
    except ValueError:
        # Paths share no common root (e.g. different drives): input is not inside output
        contains_input = False
    if contains_input:
        LOG.debug("Cannot delete: output directory is the same as input")
        warning_message(
            "Cannot delete: output directory is the same as or contains the input directory"
        )
        return
    try:
        rmtree(out_dir)
    except OSError as err:
        LOG.debug("Error while deleting directory: %s", err)
        warning_message('Error while deleting directory')
    else:
        LOG.debug("Directory with reconstructed data was removed")
def dryrun_button_pressed(self):
    """
    Sets the dry-run parameter for Tofu to True
    and calls reconstruction

    The flag is stored as a real bool, matching every other assignment of
    'main_config_dry_run'. It is reset in a `finally` clause, so an
    exception during the run cannot leave later "Reconstruct" presses
    stuck in dry-run mode.
    """
    LOG.debug("DRY")
    parameters.params['main_config_dry_run'] = True
    try:
        self.reco_button_pressed()
    finally:
        parameters.params['main_config_dry_run'] = False
def set_save_args(self):
    """Store the "Save args in .params file" checkbox state as a bool."""
    checked = self.save_params_checkbox.isChecked()
    LOG.debug("Save args: " + str(checked))
    parameters.params['main_config_save_params'] = bool(checked)
def export_settings_button_pressed(self):
    """
    Saves currently displayed GUI settings
    to an external .yaml file specified by user

    A ".yaml" suffix is appended when the chosen name has no extension.
    Nothing happens if the dialog is cancelled.
    """
    LOG.debug("Save settings pressed")
    file_name, _ = QFileDialog.getSaveFileName(
        self,
        "QFileDialog.getSaveFileName()",
        "",
        "YAML Files (*.yaml);; All Files (*)",
        options=QFileDialog.Options(),
    )
    if not file_name:
        return
    LOG.debug("Export YAML Path: " + file_name)
    if not os.path.splitext(file_name)[1]:
        file_name = file_name + ".yaml"
    # Create and write to YAML file based on given file name
    self.yaml_io.write_yaml(file_name, parameters.params)
def import_settings_button_pressed(self):
    """
    Loads external settings from .yaml file specified by user
    Signal is sent to enable updating of displayed GUI values

    The whole parameters.params dict is replaced by the file contents.
    Nothing happens if the dialog is cancelled.
    """
    LOG.debug("Import settings pressed")
    file_path, _ = QFileDialog.getOpenFileName(
        self,
        "QFileDialog.getOpenFileName()",
        "",
        "YAML Files (*.yaml);; All Files (*)",
        options=QFileDialog.Options(),
    )
    if not file_path:
        return
    LOG.debug("Import YAML Path: " + file_path)
    loaded = self.yaml_io.read_yaml(file_path)
    parameters.params = dict(loaded)
    self.signal_update_vals_from_params.emit(parameters.params)
def reco_button_pressed(self):
    """
    Gets the settings set by the user in the GUI
    These are then passed to execute_reconstruction
    """
    LOG.debug("RECO")
    current = parameters.params
    LOG.debug(current)
    self.run_reconstruction(current, batch_run=False)
def run_reconstruction(self, params, batch_run):
    """
    Validate the inputs, build the tofu argument object and run the reconstruction.

    Parameters
    ----------
    params : dict
        Parameter dictionary (normally ``parameters.params``).
    batch_run : bool
        When False, a "Finished" dialog is shown after the run.

    Errors are reported in an "Invalid Input Error" dialog:
    - InvalidInputError, raised anywhere in the run.
    - Non-numeric text in numeric fields. validate_input converts it with
      int()/float(), which raises ValueError; that is re-raised as
      InvalidInputError so it does not escape the Qt slot.
    After a non-dry run, ``signal_reco_done`` is emitted with ``params``.
    """
    try:
        try:
            self.validate_input()
        except InvalidInputError:
            raise
        except ValueError as err:
            raise InvalidInputError(
                "Invalid (non-numeric) value in input fields: " + str(err)
            ) from err
        args = tk_args(params['main_config_input_dir'],
                       params['main_config_temp_dir'],
                       params['main_config_output_dir'],
                       params['main_config_save_multipage_tiff'],
                       params['main_cor_axis_search_method'],
                       params['main_cor_axis_search_interval'],
                       params['main_cor_search_row_start'],
                       params['main_cor_recon_patch_size'],
                       params['main_cor_axis_column'],
                       params['main_cor_axis_increment_step'],
                       params['main_filters_remove_spots'],
                       params['main_filters_remove_spots_threshold'],
                       params['main_filters_remove_spots_blur_sigma'],
                       params['main_filters_ring_removal'],
                       params['main_filters_ring_removal_ufo_lpf'],
                       params['main_filters_ring_removal_ufo_lpf_1d_or_2d'],
                       params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'],
                       params['main_filters_ring_removal_ufo_lpf_sigma_vertical'],
                       params['main_filters_ring_removal_sarepy_window_size'],
                       params['main_filters_ring_removal_sarepy_wide'],
                       params['main_filters_ring_removal_sarepy_window'],
                       params['main_filters_ring_removal_sarepy_SNR'],
                       params['main_pr_phase_retrieval'],
                       params['main_pr_photon_energy'],
                       params['main_pr_pixel_size'],
                       params['main_pr_detector_distance'],
                       params['main_pr_delta_beta_ratio'],
                       params['main_region_select_rows'],
                       params['main_region_first_row'],
                       params['main_region_number_rows'],
                       params['main_region_nth_row'],
                       params['main_region_clip_histogram'],
                       params['main_region_bit_depth'],
                       params['main_region_histogram_min'],
                       params['main_region_histogram_max'],
                       params['main_config_preprocess'],
                       params['main_config_preprocess_command'],
                       params['main_region_rotate_volume_clock'],
                       params['main_region_crop_slices'],
                       params['main_region_crop_x'],
                       params['main_region_crop_width'],
                       params['main_region_crop_y'],
                       params['main_region_crop_height'],
                       params['main_config_dry_run'],
                       params['main_config_save_params'],
                       params['main_config_keep_temp'],
                       params['advanced_ffc_sinFFC'],
                       params['advanced_ffc_method'],
                       params['advanced_ffc_eigen_pco_reps'],
                       params['advanced_ffc_eigen_pco_downsample'],
                       params['advanced_ffc_downsample'],
                       params['main_config_common_flats_darks'],
                       params['main_config_darks_path'],
                       params['main_config_flats_path'],
                       params['main_config_flats2_checkbox'],
                       params['main_config_flats2_path'],
                       # NLMDN Parameters
                       params['advanced_nlmdn_apply_after_reco'],
                       params['advanced_nlmdn_input_dir'],
                       params['advanced_nlmdn_input_is_file'],
                       params['advanced_nlmdn_output_dir'],
                       params['advanced_nlmdn_save_bigtiff'],
                       params['advanced_nlmdn_sim_search_radius'],
                       params['advanced_nlmdn_patch_radius'],
                       params['advanced_nlmdn_smoothing_control'],
                       params['advanced_nlmdn_noise_std'],
                       params['advanced_nlmdn_window'],
                       params['advanced_nlmdn_fast'],
                       params['advanced_nlmdn_estimate_sigma'],
                       params['advanced_nlmdn_dry_run'],
                       # Advanced Parameters
                       params['advanced_advtofu_extended_settings'],
                       params['advanced_advtofu_lamino_angle'],
                       params['advanced_adv_tofu_z_axis_rotation'],
                       params['advanced_advtofu_center_position_z'],
                       params['advanced_advtofu_y_axis_rotation'],
                       params['advanced_advtofu_aux_ffc_dark_scale'],
                       params['advanced_advtofu_aux_ffc_flat_scale'],
                       params['advanced_optimize_verbose_console'],
                       params['advanced_optimize_slice_mem_coeff'],
                       params['advanced_optimize_num_gpus'],
                       params['advanced_optimize_slices_per_device']
                       )
        execute_reconstruction(args, self.get_fdt_names())
        if batch_run is False:
            msg = "Done. See output in terminal for details."
            QMessageBox.information(self, "Finished", msg)
        if not params['main_config_dry_run']:
            self.signal_reco_done.emit(params)
    except InvalidInputError as err:
        msg = str(err.args[0]) if err.args else str(err)
        QMessageBox.information(self, "Invalid Input Error", msg)
# NEED TO DETERMINE VALID RANGES
# ALSO CHECK TYPES SOMEHOW
def validate_input(self):
    """
    Check the user-entered reconstruction settings in ``parameters.params``.

    Each checked field is converted with the same type (``int`` or ``float``)
    that ``tk_args`` later applies to it, and must not be negative. Fields
    that may legitimately be negative (the histogram minimum when 16-bit
    output is selected, and the optional volume rotation angle) are not
    checked.

    Raises:
        InvalidInputError: if a field cannot be converted to the expected
            numeric type, or if its value is below zero. The message names
            the GUI field so it can be shown to the user directly.
    """
    params = parameters.params

    def require_non_negative(key, convert, label):
        # Report conversion failures (e.g. "abc" or "" typed into a field) as
        # InvalidInputError too, so the GUI can show them like range errors
        # instead of letting a bare ValueError/TypeError escape.
        try:
            value = convert(params[key])
        except (TypeError, ValueError):
            expected = "an integer" if convert is int else "a number"
            raise InvalidInputError(
                "Invalid value for: " + label + " (expected " + expected + ")"
            ) from None
        if value < 0:
            raise InvalidInputError("Value out of range for: " + label)

    # (params key, conversion used by tk_args, GUI label). Checked in this
    # order, so the first offending field is the one reported.
    checks = (
        # Center of rotation
        ('main_cor_search_row_start', int, "Search in slice from row number"),
        ('main_cor_recon_patch_size', int, "Size of reconstructed patch [pixel]"),
        ('main_cor_axis_column', float, "Axis is in column No [pixel]"),
        ('main_cor_axis_increment_step', float, "Increment axis every reconstruction"),
        # Large-spot removal
        ('main_filters_remove_spots_threshold', int, "Threshold (prominence of the spot) [counts]"),
        ('main_filters_remove_spots_blur_sigma', int, "Spot blur. sigma [pixels]"),
        # Ring removal (ufo Fourier filter and sarepy)
        ('main_filters_ring_removal_ufo_lpf_sigma_horizontal', int, "ufo ring-removal sigma horizontal"),
        ('main_filters_ring_removal_ufo_lpf_sigma_vertical', int, "ufo ring-removal sigma vertical"),
        ('main_filters_ring_removal_sarepy_window_size', int, "window size"),
        ('main_filters_ring_removal_sarepy_window', int, "wind"),
        ('main_filters_ring_removal_sarepy_SNR', int, "SNR"),
        # Phase retrieval
        ('main_pr_photon_energy', float, "Photon energy [keV]"),
        ('main_pr_pixel_size', float, "Pixel size [micron]"),
        ('main_pr_detector_distance', float, "Sample-detector distance [m]"),
        # tk_args converts this with float() before taking log10, so
        # decimal ratios are valid input and must not be rejected by int().
        ('main_pr_delta_beta_ratio', float, "Delta/beta ratio: (try default if unsure)"),
        # Region of interest / output
        ('main_region_first_row', int, "First row in projections"),
        ('main_region_number_rows', int, "Number of rows (ROI height)"),
        ('main_region_nth_row', int, "Reconstruct every Nth row"),
        # main_region_histogram_min is not checked: it can be negative
        # when 16-bit output is selected.
        ('main_region_histogram_max', float, "Max value in 32-bit histogram"),
        ('main_region_crop_x', int, "Crop slices: x"),
        ('main_region_crop_width', int, "Crop slices: width"),
        ('main_region_crop_y', int, "Crop slices: y"),
        ('main_region_crop_height', int, "Crop slices: height"),
        # Flat-field correction
        ('advanced_ffc_eigen_pco_reps', int, "Flat Field Correction: Eigen PCO Repetitions"),
        ('advanced_ffc_eigen_pco_downsample', int, "Flat Field Correction: Eigen PCO Downsample"),
        ('advanced_ffc_downsample', int, "Flat Field Correction: Downsample"),
        # main_region_rotate_volume_clock is not checked: it can be negative.
    )
    for key, convert, label in checks:
        require_non_negative(key, convert, label)

    # TODO: add checks for the NLMDN settings.
    # TODO: add checks for the advanced settings, e.g.
    #   advanced_advtofu_lamino_angle must be a float in [0, 90];
    #   advanced_optimize_slice_mem_coeff must be in [0, 1].
def get_fdt_names(self):
    """
    Return a new list holding the entries of ``self.e_DIRTYP``.

    The list is a shallow copy, so callers may modify it without affecting
    ``self.e_DIRTYP``. The result is also logged at DEBUG level.
    """
    DIRTYP = list(self.e_DIRTYP)
    LOG.debug("Result of get_fdt_names")
    LOG.debug(DIRTYP)
    return DIRTYP
class tk_args():
    """
    Type-converted snapshot of the GUI reconstruction parameters.

    Every constructor argument is converted and then stored twice: as an entry
    of ``self.args`` keyed by the argument name, and as an attribute of the
    same name. Most conversions are plain ``str``/``bool``/``int``/``float``;
    the exceptions are:

    * ``main_pr_pixel_size`` is multiplied by 1e-6,
    * ``main_pr_delta_beta_ratio`` is stored as its base-10 logarithm,
    * ``main_config_preprocess_command`` is stored unchanged.

    ``self.args`` is filled in the order of the conversion table below, which
    differs slightly from the argument order (e.g. output dir before temp dir).
    """

    def __init__(self, main_config_input_dir, main_config_temp_dir, main_config_output_dir, main_config_save_multipage_tiff,
                 main_cor_axis_search_method, main_cor_axis_search_interval, main_cor_search_row_start,
                 main_cor_recon_patch_size, main_cor_axis_column, main_cor_axis_increment_step,
                 main_filters_remove_spots, main_filters_remove_spots_threshold, main_filters_remove_spots_blur_sigma,
                 main_filters_ring_removal, main_filters_ring_removal_ufo_lpf, main_filters_ring_removal_ufo_lpf_1d_or_2d,
                 main_filters_ring_removal_ufo_lpf_sigma_horizontal, main_filters_ring_removal_ufo_lpf_sigma_vertical,
                 main_filters_ring_removal_sarepy_window_size, main_filters_ring_removal_sarepy_wide, main_filters_ring_removal_sarepy_window, main_filters_ring_removal_sarepy_SNR,
                 main_pr_phase_retrieval, main_pr_photon_energy, main_pr_pixel_size, main_pr_detector_distance,
                 main_pr_delta_beta_ratio, main_region_select_rows, main_region_first_row, main_region_number_rows, main_region_nth_row, main_region_clip_histogram, main_region_bit_depth, main_region_histogram_min, main_region_histogram_max,
                 main_config_preprocess, main_config_preprocess_command, main_region_rotate_volume_clock, main_region_crop_slices, main_region_crop_x, main_region_crop_width, main_region_crop_y, main_region_crop_height,
                 main_config_dry_run, main_config_save_params, main_config_keep_temp, advanced_ffc_sinFFC, advanced_ffc_method, advanced_ffc_eigen_pco_reps,
                 advanced_ffc_eigen_pco_downsample, advanced_ffc_downsample, main_config_common_flats_darks,
                 main_config_darks_path, main_config_flats_path, main_config_flats2_checkbox, main_config_flats2_path,
                 advanced_nlmdn_apply_after_reco, advanced_nlmdn_input_dir, advanced_nlmdn_input_is_file, advanced_nlmdn_output_dir, advanced_nlmdn_save_bigtiff,
                 advanced_nlmdn_sim_search_radius, advanced_nlmdn_patch_radius, advanced_nlmdn_smoothing_control, advanced_nlmdn_noise_std,
                 advanced_nlmdn_window, advanced_nlmdn_fast, advanced_nlmdn_estimate_sigma, advanced_nlmdn_dry_run,
                 advanced_advtofu_extended_settings,
                 advanced_advtofu_lamino_angle, advanced_adv_tofu_z_axis_rotation, advanced_advtofu_center_position_z, advanced_advtofu_y_axis_rotation,
                 advanced_advtofu_aux_ffc_dark_scale, advanced_advtofu_aux_ffc_flat_scale,
                 advanced_optimize_verbose_console, advanced_optimize_slice_mem_coeff, advanced_optimize_num_gpus, advanced_optimize_slices_per_device):
        # Capture the arguments by name before any other local is bound.
        supplied = dict(locals())

        def microns_to_metres(value):
            return float(value) * 1e-6

        def log10_ratio(value):
            return np.log10(float(value))

        def as_is(value):
            return value

        # (argument name, converter), in the order entries are added to args.
        conversions = (
            # paths
            ('main_config_input_dir', str),
            ('main_config_output_dir', str),
            ('main_config_temp_dir', str),
            ('main_config_save_multipage_tiff', bool),
            # center of rotation
            ('main_cor_axis_search_method', int),
            ('main_cor_axis_search_interval', str),
            ('main_cor_recon_patch_size', int),
            ('main_cor_search_row_start', int),
            ('main_cor_axis_column', float),
            ('main_cor_axis_increment_step', float),
            # spot and ring removal
            ('main_filters_remove_spots', bool),
            ('main_filters_remove_spots_threshold', int),
            ('main_filters_remove_spots_blur_sigma', int),
            ('main_filters_ring_removal', bool),
            ('main_filters_ring_removal_ufo_lpf', bool),
            ('main_filters_ring_removal_ufo_lpf_1d_or_2d', bool),
            ('main_filters_ring_removal_ufo_lpf_sigma_horizontal', int),
            ('main_filters_ring_removal_ufo_lpf_sigma_vertical', int),
            ('main_filters_ring_removal_sarepy_window_size', int),
            ('main_filters_ring_removal_sarepy_wide', bool),
            ('main_filters_ring_removal_sarepy_window', int),
            ('main_filters_ring_removal_sarepy_SNR', int),
            # phase retrieval
            ('main_pr_phase_retrieval', bool),
            ('main_pr_photon_energy', float),
            ('main_pr_pixel_size', microns_to_metres),
            ('main_pr_detector_distance', float),
            ('main_pr_delta_beta_ratio', log10_ratio),
            # vertical crop
            ('main_region_select_rows', bool),
            ('main_region_first_row', int),
            ('main_region_number_rows', int),
            ('main_region_nth_row', int),
            # bit-depth conversion
            ('main_region_clip_histogram', bool),
            ('main_region_bit_depth', int),
            ('main_region_histogram_min', float),
            ('main_region_histogram_max', float),
            # preprocessing
            ('main_config_preprocess', bool),
            ('main_config_preprocess_command', as_is),
            # ROI in slice
            ('main_region_crop_slices', bool),
            ('main_region_crop_x', int),
            ('main_region_crop_width', int),
            ('main_region_crop_y', int),
            ('main_region_crop_height', int),
            # optional FBP parameters
            ('main_region_rotate_volume_clock', float),
            # misc settings
            ('main_config_dry_run', bool),
            ('main_config_save_params', bool),
            ('main_config_keep_temp', bool),
            # sinFFC settings
            ('advanced_ffc_sinFFC', bool),
            ('advanced_ffc_method', str),
            ('advanced_ffc_eigen_pco_reps', int),
            ('advanced_ffc_eigen_pco_downsample', int),
            ('advanced_ffc_downsample', int),
            # flats/darks shared across multiple experiments
            ('main_config_common_flats_darks', bool),
            ('main_config_darks_path', str),
            ('main_config_flats_path', str),
            ('main_config_flats2_checkbox', bool),
            ('main_config_flats2_path', str),
            # NLMDN settings
            ('advanced_nlmdn_apply_after_reco', bool),
            ('advanced_nlmdn_input_dir', str),
            ('advanced_nlmdn_input_is_file', bool),
            ('advanced_nlmdn_output_dir', str),
            ('advanced_nlmdn_save_bigtiff', bool),
            ('advanced_nlmdn_sim_search_radius', str),
            ('advanced_nlmdn_patch_radius', str),
            ('advanced_nlmdn_smoothing_control', str),
            ('advanced_nlmdn_noise_std', str),
            ('advanced_nlmdn_window', str),
            ('advanced_nlmdn_fast', bool),
            ('advanced_nlmdn_estimate_sigma', bool),
            ('advanced_nlmdn_dry_run', bool),
            # advanced settings
            ('advanced_advtofu_extended_settings', bool),
            ('advanced_advtofu_lamino_angle', str),
            ('advanced_adv_tofu_z_axis_rotation', str),
            ('advanced_advtofu_center_position_z', str),
            ('advanced_advtofu_y_axis_rotation', str),
            ('advanced_advtofu_aux_ffc_dark_scale', str),
            ('advanced_advtofu_aux_ffc_flat_scale', str),
            # optimization
            ('advanced_optimize_verbose_console', bool),
            ('advanced_optimize_slice_mem_coeff', str),
            ('advanced_optimize_num_gpus', str),
            ('advanced_optimize_slices_per_device', str),
        )

        self.args = {}
        for name, convert in conversions:
            converted = convert(supplied[name])
            self.args[name] = converted
            setattr(self, name, converted)
        LOG.debug("Contents of arg dict: ")
        LOG.debug(self.args.items())
class InvalidInputError(Exception):
    """
    Raised when a value taken from the GUI is invalid or out of range.

    The first positional argument is the human-readable message; the
    reconstruction handler displays ``err.args[0]`` in a message box.
    """
tofu-0.12.0/tofu/ez/GUI/Main/filters.py 0000664 0000000 0000000 00000031177 14237137211 0017506 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import (
QButtonGroup,
QGridLayout,
QLabel,
QRadioButton,
QCheckBox,
QGroupBox,
QLineEdit,
)
from PyQt5.QtCore import Qt
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class FiltersGroup(QGroupBox):
"""
Filter settings
"""
def __init__(self):
    """Create the filter widgets, connect their signals and lay them out."""
    super().__init__()
    self.setTitle("Filters")
    self.setStyleSheet("QGroupBox {color: orange;}")

    # Large-spot removal
    self.remove_spots_checkBox = QCheckBox()
    self.remove_spots_checkBox.setText("Remove large spots from projections")
    self.remove_spots_checkBox.setToolTip(
        "Efficiently suppresses very intense rings \n stemming from defects in scintillator"
    )
    self.remove_spots_checkBox.stateChanged.connect(self.set_remove_spots)
    self.threshold_label = QLabel()
    self.threshold_label.setText("Threshold (prominence of the spot) [counts]")
    self.threshold_label.setToolTip(
        "Outliers which will be considered as the part of the large spot"
    )
    self.threshold_entry = QLineEdit()
    self.threshold_entry.editingFinished.connect(self.set_threshold)
    self.spot_blur_label = QLabel()
    self.spot_blur_label.setText("Spot blur. sigma [pixels]")
    self.spot_blur_label.setToolTip(
        "Regulates extent of the masked region around the detected outlier"
    )
    self.spot_blur_entry = QLineEdit()
    self.spot_blur_entry.editingFinished.connect(self.set_spot_blur)

    # Ring removal
    self.enable_RR_checkbox = QCheckBox()
    self.enable_RR_checkbox.setText("Enable ring removal")
    # This tooltip describes ring removal. It was previously set on
    # remove_spots_checkBox, overwriting that checkbox's own tooltip.
    self.enable_RR_checkbox.setToolTip(
        "To suppress ring artifacts"
        " stemming from intensity fluctuations and detector nonlinearities"
    )
    self.enable_RR_checkbox.stateChanged.connect(self.set_ring_removal)
    self.use_LPF_rButton = QRadioButton()
    self.use_LPF_rButton.setText("Use ufo Fourier-transform based filter")
    self.use_LPF_rButton.clicked.connect(self.select_rButton)
    self.sarepy_rButton = QRadioButton()
    self.sarepy_rButton.setText("Use sarepy sorting: ")
    self.sarepy_rButton.clicked.connect(self.select_rButton)
    self.sarepy_rButton.setToolTip(
        "Non-FFT based algorithms from \n /Nghia T. Vo et al, Opt. Express 26, 28396 (2018)"
    )
    # The two ring-removal methods are mutually exclusive.
    self.filter_rButton_group = QButtonGroup(self)
    self.filter_rButton_group.addButton(self.use_LPF_rButton)
    self.filter_rButton_group.addButton(self.sarepy_rButton)

    # ufo filter: 1D / 2D choice and its two sigmas
    self.one_dimens_rButton = QRadioButton()
    self.one_dimens_rButton.setText("1D")
    self.one_dimens_rButton.clicked.connect(self.select_dimens_rButton)
    self.one_dimens_rButton.setToolTip("Only low-pass filter along the lines of sinogram")
    self.two_dimens_rButton = QRadioButton()
    self.two_dimens_rButton.setText("2D")
    self.two_dimens_rButton.clicked.connect(self.select_dimens_rButton)
    self.two_dimens_rButton.setToolTip(
        "Low-pass filter along the lines and high-pass filter along the columns"
    )
    self.dimens_rButton_group = QButtonGroup(self)
    self.dimens_rButton_group.addButton(self.one_dimens_rButton)
    self.dimens_rButton_group.addButton(self.two_dimens_rButton)
    self.sigma_horizontal_label = QLabel()
    self.sigma_horizontal_label.setText("sigma horizontal")
    self.sigma_horizontal_label.setToolTip(
        "Width [pixels] of Gaussian-shaped low-pass filter in frequency domain"
    )
    self.sigma_horizontal_entry = QLineEdit()
    self.sigma_horizontal_entry.editingFinished.connect(self.set_sigma_horizontal)
    self.sigma_vertical_label = QLabel()
    self.sigma_vertical_label.setText("sigma vertical")
    self.sigma_vertical_label.setToolTip(
        "Width [pixels] of Gaussian-shaped high-pass filter in frequency domain"
    )
    self.sigma_vertical_entry = QLineEdit()
    self.sigma_vertical_entry.editingFinished.connect(self.set_sigma_vertical)

    # sarepy: window size, optional wide-stripe removal and its SNR
    self.wind_size_label = QLabel()
    self.wind_size_label.setText("window size")
    self.wind_size_label.setToolTip("Window size in remove_stripe_based_sorting algorithm")
    self.wind_size_entry = QLineEdit()
    self.wind_size_entry.editingFinished.connect(self.set_window_size)
    self.wind_size_entry.setToolTip("Typically in the range 31..51 ")
    self.remove_wide_checkbox = QCheckBox()
    self.remove_wide_checkbox.setText("Remove wide")
    # NOTE(review): by analogy with the window-size widgets above, this tooltip
    # may belong on remove_wide_label and the "Typically..." text on
    # remove_wide_entry -- confirm before changing.
    self.remove_wide_checkbox.setToolTip("Window size in remove_large_stripe algorithm")
    self.remove_wide_checkbox.stateChanged.connect(self.set_remove_wide)
    self.remove_wide_label = QLabel()
    self.remove_wide_label.setText("window")
    self.remove_wide_label.setToolTip("Typically in the range 51..131 ")
    self.remove_wide_entry = QLineEdit()
    self.remove_wide_entry.editingFinished.connect(self.set_wind)
    self.SNR_label = QLabel()
    self.SNR_label.setText("SNR")
    self.SNR_label.setToolTip("SNR param in remove_large_stripe algorithm")
    self.SNR_entry = QLineEdit()
    self.SNR_entry.editingFinished.connect(self.set_SNR)

    self.set_layout()
def set_layout(self):
    """Place the filter widgets into two framed sub-groups stacked vertically."""
    outer = QGridLayout()

    # Large-spot removal: checkbox, then threshold and blur rows.
    spots_box = QGroupBox()
    spots_grid = QGridLayout()
    spot_cells = (
        (self.remove_spots_checkBox, 0, 0),
        (self.threshold_label, 1, 0),
        (self.threshold_entry, 1, 1, 1, 7),
        (self.spot_blur_label, 2, 0),
        (self.spot_blur_entry, 2, 1, 1, 7),
    )
    for widget, *placement in spot_cells:
        spots_grid.addWidget(widget, *placement)
    spots_box.setLayout(spots_grid)
    outer.addWidget(spots_box)

    # Ring removal: enable checkbox, ufo filter row, sarepy row.
    ring_box = QGroupBox()
    ring_grid = QGridLayout()
    ring_cells = (
        (self.enable_RR_checkbox, 3, 0),
        (self.use_LPF_rButton, 4, 0),
        (self.one_dimens_rButton, 4, 1),
        (self.two_dimens_rButton, 4, 2),
        (self.sigma_horizontal_label, 4, 3, Qt.AlignRight),
        (self.sigma_horizontal_entry, 4, 4),
        (self.sigma_vertical_label, 4, 5, Qt.AlignRight),
        (self.sigma_vertical_entry, 4, 6),
        (self.sarepy_rButton, 5, 0),
        (self.wind_size_label, 5, 1),
        (self.wind_size_entry, 5, 2),
        (self.remove_wide_checkbox, 5, 3),
        (self.remove_wide_label, 5, 4, Qt.AlignRight),
        (self.remove_wide_entry, 5, 5),
        (self.SNR_label, 5, 6),
        (self.SNR_entry, 5, 7),
    )
    for widget, *placement in ring_cells:
        ring_grid.addWidget(widget, *placement)
    ring_box.setLayout(ring_grid)
    outer.addWidget(ring_box, 3, 0)

    self.setLayout(outer)
def init_values(self):
    """
    Reset the filter widgets to their defaults and keep ``parameters.params``
    consistent with what the widgets show.

    Spot-removal and ufo-filter entries are filled from the current
    ``parameters.params`` values. The sarepy entries get fixed defaults
    (window size 21, wide window 91, SNR 3). These are also written to
    ``parameters.params`` as strings, the same type the entry callbacks store.
    """
    # Large-spot removal is off by default. (This line previously unchecked
    # remove_wide_checkbox by mistake, leaving this checkbox untouched.)
    self.remove_spots_checkBox.setChecked(False)
    self.set_remove_spots()
    parameters.params['main_filters_remove_spots'] = False
    self.threshold_entry.setText(
        str(parameters.params['main_filters_remove_spots_threshold'])
    )
    self.spot_blur_entry.setText(
        str(parameters.params['main_filters_remove_spots_blur_sigma'])
    )

    # Ring removal is off by default; the 2D ufo Fourier filter is preselected.
    self.enable_RR_checkbox.setChecked(False)
    self.set_ring_removal()
    parameters.params['main_filters_ring_removal'] = False
    self.use_LPF_rButton.setChecked(True)
    self.select_rButton()
    self.sarepy_rButton.setChecked(False)
    self.two_dimens_rButton.setChecked(True)
    parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = False
    self.sigma_horizontal_entry.setText(
        str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'])
    )
    self.sigma_vertical_entry.setText(
        str(parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical'])
    )

    # sarepy defaults. setText() does not emit editingFinished, so the entry
    # callbacks never run here; write the matching parameters explicitly so
    # the reconstruction uses the values that are displayed.
    self.wind_size_entry.setText("21")
    parameters.params['main_filters_ring_removal_sarepy_window_size'] = "21"
    self.remove_wide_checkbox.setChecked(False)
    parameters.params['main_filters_ring_removal_sarepy_wide'] = False
    self.remove_wide_entry.setText("91")
    parameters.params['main_filters_ring_removal_sarepy_window'] = "91"
    self.SNR_entry.setText("3")
    parameters.params['main_filters_ring_removal_sarepy_SNR'] = "3"
def set_values_from_params(self):
    """Update every filter widget to reflect the values in ``parameters.params``."""
    params = parameters.params
    self.remove_spots_checkBox.setChecked(params['main_filters_remove_spots'])
    self.threshold_entry.setText(str(params['main_filters_remove_spots_threshold']))
    self.spot_blur_entry.setText(str(params['main_filters_remove_spots_blur_sigma']))
    self.enable_RR_checkbox.setChecked(params['main_filters_ring_removal'])
    # Both radio pairs live in exclusive QButtonGroups. There, setChecked(False)
    # on the checked button is ignored, so always check the button that
    # should be selected; Qt unchecks its partner automatically.
    if params['main_filters_ring_removal_ufo_lpf']:
        self.use_LPF_rButton.setChecked(True)
    else:
        self.sarepy_rButton.setChecked(True)
    if params['main_filters_ring_removal_ufo_lpf_1d_or_2d']:
        self.one_dimens_rButton.setChecked(True)
    else:
        self.two_dimens_rButton.setChecked(True)
    self.sigma_horizontal_entry.setText(str(params['main_filters_ring_removal_ufo_lpf_sigma_horizontal']))
    self.sigma_vertical_entry.setText(str(params['main_filters_ring_removal_ufo_lpf_sigma_vertical']))
    self.wind_size_entry.setText(str(params['main_filters_ring_removal_sarepy_window_size']))
    self.remove_wide_checkbox.setChecked(params['main_filters_ring_removal_sarepy_wide'])
    self.remove_wide_entry.setText(str(params['main_filters_ring_removal_sarepy_window']))
    self.SNR_entry.setText(str(params['main_filters_ring_removal_sarepy_SNR']))
def set_remove_spots(self):
    """Mirror the large-spot-removal checkbox state into parameters.params."""
    enabled = bool(self.remove_spots_checkBox.isChecked())
    LOG.debug("Remove large spots:" + str(enabled))
    parameters.params['main_filters_remove_spots'] = enabled
def set_threshold(self):
    """Store the spot-threshold entry text (as str) in parameters.params."""
    entered = self.threshold_entry.text()
    LOG.debug(entered)
    parameters.params['main_filters_remove_spots_threshold'] = str(entered)
def set_spot_blur(self):
    """Store the spot-blur sigma entry text (as str) in parameters.params."""
    entered = self.spot_blur_entry.text()
    LOG.debug(entered)
    parameters.params['main_filters_remove_spots_blur_sigma'] = str(entered)
def set_ring_removal(self):
    """Mirror the ring-removal checkbox state into parameters.params."""
    enabled = bool(self.enable_RR_checkbox.isChecked())
    LOG.debug("RR: " + str(enabled))
    parameters.params['main_filters_ring_removal'] = enabled
def select_rButton(self):
    """Record which ring-removal method radio button is selected.

    Sets 'main_filters_ring_removal_ufo_lpf' to True for the ufo filter and
    False for sarepy; leaves it untouched if neither button is checked.
    """
    if self.use_LPF_rButton.isChecked():
        use_lpf, note = True, "Use LPF"
    elif self.sarepy_rButton.isChecked():
        use_lpf, note = False, "Use Sarepy"
    else:
        return
    LOG.debug(note)
    parameters.params['main_filters_ring_removal_ufo_lpf'] = use_lpf
def select_dimens_rButton(self):
    """Record whether the 1D or 2D ufo filter radio button is selected.

    Sets 'main_filters_ring_removal_ufo_lpf_1d_or_2d' to True for 1D and
    False for 2D; leaves it untouched if neither button is checked.
    """
    if self.one_dimens_rButton.isChecked():
        one_dimensional, note = True, "One dimension"
    elif self.two_dimens_rButton.isChecked():
        one_dimensional, note = False, "Two dimensions"
    else:
        return
    LOG.debug(note)
    parameters.params['main_filters_ring_removal_ufo_lpf_1d_or_2d'] = one_dimensional
def set_sigma_horizontal(self):
    """Copy the horizontal LPF sigma text into parameters.params (stored as a string)."""
    sigma_text = self.sigma_horizontal_entry.text()
    LOG.debug(sigma_text)
    parameters.params['main_filters_ring_removal_ufo_lpf_sigma_horizontal'] = str(sigma_text)
def set_sigma_vertical(self):
    """Copy the vertical LPF sigma text into parameters.params (stored as a string)."""
    sigma_text = self.sigma_vertical_entry.text()
    LOG.debug(sigma_text)
    parameters.params['main_filters_ring_removal_ufo_lpf_sigma_vertical'] = str(sigma_text)
def set_window_size(self):
    """Copy the Sarepy window-size text into parameters.params (stored as a string)."""
    size_text = self.wind_size_entry.text()
    LOG.debug(size_text)
    parameters.params['main_filters_ring_removal_sarepy_window_size'] = str(size_text)
def set_remove_wide(self):
    """Store the state of the Sarepy "remove wide" checkbox in parameters.params."""
    wide = self.remove_wide_checkbox.isChecked()
    LOG.debug("Wide: " + str(wide))
    parameters.params['main_filters_ring_removal_sarepy_wide'] = bool(wide)
def set_wind(self):
    """Copy the Sarepy wide-stripe window text into parameters.params (stored as a string)."""
    window_text = self.remove_wide_entry.text()
    LOG.debug(window_text)
    parameters.params['main_filters_ring_removal_sarepy_window'] = str(window_text)
def set_SNR(self):
LOG.debug(self.SNR_entry.text())
parameters.params['main_filters_ring_removal_sarepy_SNR'] = str(self.SNR_entry.text()) tofu-0.12.0/tofu/ez/GUI/Main/phase_retrieval.py 0000664 0000000 0000000 00000010231 14237137211 0021177 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import QGridLayout, QLabel, QGroupBox, QLineEdit, QCheckBox
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class PhaseRetrievalGroup(QGroupBox):
"""
Phase Retrieval settings
"""
def __init__(self):
super().__init__()
self.setTitle("Phase Retrieval")
self.setStyleSheet("QGroupBox {color: blue;}")
self.enable_PR_checkBox = QCheckBox()
self.enable_PR_checkBox.setText("Enable Paganin/TIE phase retrieval")
self.enable_PR_checkBox.stateChanged.connect(self.set_PR)
self.photon_energy_label = QLabel()
self.photon_energy_label.setText("Photon energy [keV]")
self.photon_energy_entry = QLineEdit()
self.photon_energy_entry.editingFinished.connect(self.set_photon_energy)
self.photon_energy_entry.setStyleSheet("background-color:white")
self.pixel_size_label = QLabel()
self.pixel_size_label.setText("Pixel size [micron]")
self.pixel_size_entry = QLineEdit()
self.pixel_size_entry.editingFinished.connect(self.set_pixel_size)
self.pixel_size_entry.setStyleSheet("background-color:white")
self.detector_distance_label = QLabel()
self.detector_distance_label.setText("Sample-detector distance [m]")
self.detector_distance_entry = QLineEdit()
self.detector_distance_entry.editingFinished.connect(self.set_detector_distance)
self.detector_distance_entry.setStyleSheet("background-color:white")
self.delta_beta_ratio_label = QLabel()
self.delta_beta_ratio_label.setText("Delta/beta ratio: (try default if unsure)")
self.delta_beta_ratio_entry = QLineEdit()
self.delta_beta_ratio_entry.editingFinished.connect(self.set_delta_beta)
self.delta_beta_ratio_entry.setStyleSheet("background-color:white")
self.set_layout()
def set_layout(self):
layout = QGridLayout()
layout.addWidget(self.enable_PR_checkBox, 0, 0)
layout.addWidget(self.photon_energy_label, 1, 0)
layout.addWidget(self.photon_energy_entry, 1, 1)
layout.addWidget(self.pixel_size_label, 2, 0)
layout.addWidget(self.pixel_size_entry, 2, 1)
layout.addWidget(self.detector_distance_label, 3, 0)
layout.addWidget(self.detector_distance_entry, 3, 1)
layout.addWidget(self.delta_beta_ratio_label, 4, 0)
layout.addWidget(self.delta_beta_ratio_entry, 4, 1)
self.setLayout(layout)
def init_values(self):
self.enable_PR_checkBox.setChecked(False)
parameters.params['main_pr_phase_retrieval'] = False
self.photon_energy_entry.setText("20")
self.pixel_size_entry.setText("3.6")
self.detector_distance_entry.setText("0.1")
self.delta_beta_ratio_entry.setText("200")
def set_values_from_params(self):
self.enable_PR_checkBox.setChecked(parameters.params['main_pr_phase_retrieval'])
self.photon_energy_entry.setText(str(parameters.params['main_pr_photon_energy']))
self.pixel_size_entry.setText(str(parameters.params['main_pr_pixel_size']))
self.detector_distance_entry.setText(str(parameters.params['main_pr_detector_distance']))
self.delta_beta_ratio_entry.setText(str(parameters.params['main_pr_delta_beta_ratio']))
def set_PR(self):
LOG.debug("PR: " + str(self.enable_PR_checkBox.isChecked()))
parameters.params['main_pr_phase_retrieval'] = bool(self.enable_PR_checkBox.isChecked())
def set_photon_energy(self):
LOG.debug(self.photon_energy_entry.text())
parameters.params['main_pr_photon_energy'] = str(self.photon_energy_entry.text())
def set_pixel_size(self):
LOG.debug(self.pixel_size_entry.text())
parameters.params['main_pr_pixel_size'] = str(self.pixel_size_entry.text())
def set_detector_distance(self):
LOG.debug(self.detector_distance_entry.text())
parameters.params['main_pr_detector_distance'] = str(self.detector_distance_entry.text())
def set_delta_beta(self):
LOG.debug(self.delta_beta_ratio_entry.text())
parameters.params['main_pr_delta_beta_ratio'] = str(self.delta_beta_ratio_entry.text()) tofu-0.12.0/tofu/ez/GUI/Main/region_and_histogram.py 0000664 0000000 0000000 00000024573 14237137211 0022222 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtWidgets import QGridLayout, QRadioButton, QLabel, QGroupBox, QLineEdit, QCheckBox
from PyQt5.QtCore import Qt
import tofu.ez.params as parameters
LOG = logging.getLogger(__name__)
class ROIandHistGroup(QGroupBox):
    """
    Region-of-interest and histogram settings.

    Covers the selection of rows to reconstruct, clipping of the 32-bit
    histogram with 8- or 16-bit output, cropping of slices in the
    reconstruction plane and rotation of the volume. Each widget copies its
    value into ``parameters.params`` (``main_region_*`` keys) when the user
    edits it; line-edit values are stored as strings.
    """

    def __init__(self):
        super().__init__()

        self.setTitle("Region of Interest and Histogram Settings")
        self.setStyleSheet("QGroupBox {color: red;}")

        # Row selection
        self.select_rows_checkbox = QCheckBox()
        self.select_rows_checkbox.setText("Select rows which will be reconstructed")
        self.select_rows_checkbox.stateChanged.connect(self.set_select_rows)

        self.first_row_label = QLabel()
        self.first_row_label.setText("First row in projections")
        self.first_row_label.setToolTip("Counting from the top")
        self.first_row_entry = QLineEdit()
        self.first_row_entry.editingFinished.connect(self.set_first_row)

        self.num_rows_label = QLabel()
        self.num_rows_label.setText("Number of rows (ROI height)")
        self.num_rows_entry = QLineEdit()
        self.num_rows_entry.editingFinished.connect(self.set_num_rows)

        self.nth_row_label = QLabel()
        self.nth_row_label.setText("Step (reconstruct every Nth row)")
        self.nth_row_entry = QLineEdit()
        self.nth_row_entry.editingFinished.connect(self.set_reco_nth_rows)

        # Histogram clipping and output bit depth
        self.clip_histo_checkbox = QCheckBox()
        self.clip_histo_checkbox.setText("Clip histogram and save slices in")
        self.clip_histo_checkbox.stateChanged.connect(self.set_clip_histo)

        self.eight_bit_rButton = QRadioButton()
        self.eight_bit_rButton.setText("8-bit")
        self.eight_bit_rButton.setChecked(True)
        self.eight_bit_rButton.clicked.connect(self.set_bitdepth)

        self.sixteen_bit_rButton = QRadioButton()
        self.sixteen_bit_rButton.setText("16-bit")
        self.sixteen_bit_rButton.clicked.connect(self.set_bitdepth)

        self.min_val_label = QLabel()
        self.min_val_label.setText("Min value in 32-bit histogram")
        self.min_val_entry = QLineEdit()
        self.min_val_entry.editingFinished.connect(self.set_min_val)

        self.max_val_label = QLabel()
        self.max_val_label.setText("Max value in 32-bit histogram")
        self.max_val_entry = QLineEdit()
        self.max_val_entry.editingFinished.connect(self.set_max_val)

        # Cropping in the reconstruction plane
        self.crop_slices_checkbox = QCheckBox()
        self.crop_slices_checkbox.setText("Crop slices")
        self.crop_slices_checkbox.setToolTip("Crop slices in the reconstruction plane \n"
                                             "(x,y) - top left corner of selection \n"
                                             "(width, height) - size of selection")
        self.crop_slices_checkbox.stateChanged.connect(self.set_crop_slices)

        self.x_val_label = QLabel()
        self.x_val_label.setText("x")
        self.x_val_label.setToolTip("First column (counting from left)")
        self.x_val_entry = QLineEdit()
        self.x_val_entry.editingFinished.connect(self.set_x)

        self.width_val_label = QLabel()
        self.width_val_label.setText("width")
        self.width_val_entry = QLineEdit()
        self.width_val_entry.editingFinished.connect(self.set_width)

        self.y_val_label = QLabel()
        self.y_val_label.setText("y")
        self.y_val_label.setToolTip("First row (counting from top)")
        self.y_val_entry = QLineEdit()
        self.y_val_entry.editingFinished.connect(self.set_y)

        self.height_val_label = QLabel()
        self.height_val_label.setText("height")
        self.height_val_entry = QLineEdit()
        self.height_val_entry.editingFinished.connect(self.set_height)

        # Volume rotation
        self.rotate_vol_label = QLabel()
        self.rotate_vol_label.setText("Rotate volume clockwise by [deg]")
        self.rotate_vol_entry = QLineEdit()
        self.rotate_vol_entry.editingFinished.connect(self.set_rotate_volume)

        self.set_layout()

    def set_layout(self):
        """
        Arrange the row-selection, histogram, crop and rotation widgets in a grid.

        Single entries span columns 1-8; the crop row uses those eight columns
        for the x / width / y / height label-entry pairs.
        """
        layout = QGridLayout()
        layout.addWidget(self.select_rows_checkbox, 0, 0)
        layout.addWidget(self.first_row_label, 1, 0)
        layout.addWidget(self.first_row_entry, 1, 1, 1, 8)
        layout.addWidget(self.num_rows_label, 2, 0)
        layout.addWidget(self.num_rows_entry, 2, 1, 1, 8)
        layout.addWidget(self.nth_row_label, 3, 0)
        layout.addWidget(self.nth_row_entry, 3, 1, 1, 8)
        layout.addWidget(self.clip_histo_checkbox, 4, 0)
        layout.addWidget(self.eight_bit_rButton, 4, 1)
        layout.addWidget(self.sixteen_bit_rButton, 4, 2)
        layout.addWidget(self.min_val_label, 5, 0)
        layout.addWidget(self.min_val_entry, 5, 1, 1, 8)
        layout.addWidget(self.max_val_label, 6, 0)
        layout.addWidget(self.max_val_entry, 6, 1, 1, 8)
        layout.addWidget(self.crop_slices_checkbox, 7, 0)
        layout.addWidget(self.x_val_label, 7, 1)
        layout.addWidget(self.x_val_entry, 7, 2)
        layout.addWidget(self.width_val_label, 7, 3)
        layout.addWidget(self.width_val_entry, 7, 4)
        layout.addWidget(self.y_val_label, 7, 5)
        layout.addWidget(self.y_val_entry, 7, 6)
        layout.addWidget(self.height_val_label, 7, 7)
        layout.addWidget(self.height_val_entry, 7, 8)
        layout.addWidget(self.rotate_vol_label, 8, 0)
        layout.addWidget(self.rotate_vol_entry, 8, 1, 1, 8)
        self.setLayout(layout)

    def init_values(self):
        """Reset the widgets to their defaults.

        Also writes the checkbox and bit-depth defaults into parameters.params.
        """
        self.select_rows_checkbox.setChecked(False)
        parameters.params['main_region_select_rows'] = False
        self.first_row_entry.setText("100")
        self.num_rows_entry.setText("200")
        self.nth_row_entry.setText("20")
        self.clip_histo_checkbox.setChecked(False)
        parameters.params['main_region_clip_histogram'] = False
        self.eight_bit_rButton.setChecked(True)
        parameters.params['main_region_bit_depth'] = str(8)
        self.min_val_entry.setText("0.0")
        self.max_val_entry.setText("0.0")
        self.crop_slices_checkbox.setChecked(False)
        parameters.params['main_region_crop_slices'] = False
        # NOTE(review): QLineEdit.setText() does not emit editingFinished, so the
        # line-edit defaults above/below are not copied into parameters.params
        # here; confirm those main_region_* values are populated elsewhere.
        self.x_val_entry.setText("0")
        self.width_val_entry.setText("0")
        self.y_val_entry.setText("0")
        self.height_val_entry.setText("0")
        self.rotate_vol_entry.setText("0.0")

    def set_values_from_params(self):
        """Show the values currently stored in parameters.params in the widgets."""
        self.select_rows_checkbox.setChecked(parameters.params['main_region_select_rows'])
        self.first_row_entry.setText(str(parameters.params['main_region_first_row']))
        self.num_rows_entry.setText(str(parameters.params['main_region_number_rows']))
        self.nth_row_entry.setText(str(parameters.params['main_region_nth_row']))
        self.clip_histo_checkbox.setChecked(parameters.params['main_region_clip_histogram'])
        # Stored as a string ("8"/"16") by set_bitdepth/init_values; any other
        # value leaves the radio buttons unchanged.
        bit_depth = int(parameters.params['main_region_bit_depth'])
        if bit_depth == 8:
            self.eight_bit_rButton.setChecked(True)
            self.sixteen_bit_rButton.setChecked(False)
        elif bit_depth == 16:
            self.eight_bit_rButton.setChecked(False)
            self.sixteen_bit_rButton.setChecked(True)
        self.min_val_entry.setText(str(parameters.params['main_region_histogram_min']))
        self.max_val_entry.setText(str(parameters.params['main_region_histogram_max']))
        self.crop_slices_checkbox.setChecked(parameters.params['main_region_crop_slices'])
        self.x_val_entry.setText(str(parameters.params['main_region_crop_x']))
        self.width_val_entry.setText(str(parameters.params['main_region_crop_width']))
        self.y_val_entry.setText(str(parameters.params['main_region_crop_y']))
        self.height_val_entry.setText(str(parameters.params['main_region_crop_height']))
        self.rotate_vol_entry.setText(str(parameters.params['main_region_rotate_volume_clock']))

    # --- Slots: each copies one widget's value into parameters.params ---

    def set_select_rows(self):
        LOG.debug("Select rows: " + str(self.select_rows_checkbox.isChecked()))
        parameters.params['main_region_select_rows'] = bool(self.select_rows_checkbox.isChecked())

    def set_first_row(self):
        LOG.debug(self.first_row_entry.text())
        parameters.params['main_region_first_row'] = str(self.first_row_entry.text())

    def set_num_rows(self):
        LOG.debug(self.num_rows_entry.text())
        parameters.params['main_region_number_rows'] = str(self.num_rows_entry.text())

    def set_reco_nth_rows(self):
        LOG.debug(self.nth_row_entry.text())
        parameters.params['main_region_nth_row'] = str(self.nth_row_entry.text())

    def set_clip_histo(self):
        LOG.debug("Clip histo: " + str(self.clip_histo_checkbox.isChecked()))
        parameters.params['main_region_clip_histogram'] = bool(self.clip_histo_checkbox.isChecked())

    def set_bitdepth(self):
        """Store the selected output bit depth as the string "8" or "16"."""
        if self.eight_bit_rButton.isChecked():
            LOG.debug("8 bit")
            parameters.params['main_region_bit_depth'] = str(8)
        elif self.sixteen_bit_rButton.isChecked():
            LOG.debug("16 bit")
            parameters.params['main_region_bit_depth'] = str(16)

    def set_min_val(self):
        LOG.debug(self.min_val_entry.text())
        parameters.params['main_region_histogram_min'] = str(self.min_val_entry.text())

    def set_max_val(self):
        LOG.debug(self.max_val_entry.text())
        parameters.params['main_region_histogram_max'] = str(self.max_val_entry.text())

    def set_crop_slices(self):
        LOG.debug("Crop slices: " + str(self.crop_slices_checkbox.isChecked()))
        parameters.params['main_region_crop_slices'] = bool(self.crop_slices_checkbox.isChecked())

    def set_x(self):
        LOG.debug(self.x_val_entry.text())
        parameters.params['main_region_crop_x'] = str(self.x_val_entry.text())

    def set_width(self):
        LOG.debug(self.width_val_entry.text())
        parameters.params['main_region_crop_width'] = str(self.width_val_entry.text())

    def set_y(self):
        LOG.debug(self.y_val_entry.text())
        parameters.params['main_region_crop_y'] = str(self.y_val_entry.text())

    def set_height(self):
        LOG.debug(self.height_val_entry.text())
        parameters.params['main_region_crop_height'] = str(self.height_val_entry.text())

    def set_rotate_volume(self):
        LOG.debug(self.rotate_vol_entry.text())
        parameters.params["main_region_rotate_volume_clock"] = str(self.rotate_vol_entry.text())
tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ 0000775 0000000 0000000 00000000000 14237137211 0020073 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0022172 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_multi_stitch_qt.py 0000664 0000000 0000000 00000051664 14237137211 0024763 0 ustar 00root root 0000000 0000000 from PyQt5.QtWidgets import (
QGroupBox,
QPushButton,
QCheckBox,
QLabel,
QLineEdit,
QGridLayout,
QFileDialog,
QMessageBox,
)
from PyQt5.QtCore import pyqtSignal
import logging
from shutil import rmtree
import os
import getpass
import yaml
from tofu.ez.Helpers.stitch_funcs import main_360_mp_depth2
from tofu.ez.GUI.message_dialog import warning_message
# Params
import tofu.ez.params as params
LOG = logging.getLogger(__name__)
class MultiStitch360Group(QGroupBox):
get_fdt_names_on_stitch_pressed = pyqtSignal()
def __init__(self):
super().__init__()
self.setTitle("Batch horizontal stitching of half-acquistion mode data sets")
self.setStyleSheet('QGroupBox {color: red;}')
self.input_dir_button = QPushButton("Select input directory")
self.input_dir_button.clicked.connect(self.input_button_pressed)
self.input_dir_entry = QLineEdit()
self.input_dir_entry.editingFinished.connect(self.set_input_entry)
self.temp_dir_button = QPushButton("Select temporary directory - default value recommended")
self.temp_dir_button.clicked.connect(self.temp_button_pressed)
self.temp_dir_entry = QLineEdit()
self.temp_dir_entry.editingFinished.connect(self.set_temp_entry)
self.output_dir_button = QPushButton("Directory to save stitched images")
self.output_dir_button.clicked.connect(self.output_button_pressed)
self.output_dir_entry = QLineEdit()
self.output_dir_entry.editingFinished.connect(self.set_output_entry)
self.crop_checkbox = QCheckBox("Crop all projections to match the width of smallest stitched projection")
self.crop_checkbox.clicked.connect(self.set_crop_projections_checkbox)
self.axis_bottom_label = QLabel()
self.axis_bottom_label.setText("Axis of Rotation (Dir 00):")
self.axis_bottom_entry = QLineEdit()
self.axis_bottom_entry.editingFinished.connect(self.set_axis_bottom)
self.axis_top_label = QLabel("Axis of Rotation (Dir 0N):")
self.axis_group = QGroupBox("Enter axis of rotation manually")
self.axis_group.clicked.connect(self.set_axis_group)
self.axis_top_entry = QLineEdit()
self.axis_top_entry.editingFinished.connect(self.set_axis_top)
self.axis_z000_label = QLabel("Axis of Rotation (Dir 00):")
self.axis_z000_entry = QLineEdit()
self.axis_z000_entry.editingFinished.connect(self.set_z000)
self.axis_z001_label = QLabel("Axis of Rotation (Dir 01):")
self.axis_z001_entry = QLineEdit()
self.axis_z001_entry.editingFinished.connect(self.set_z001)
self.axis_z002_label = QLabel("Axis of Rotation (Dir 02):")
self.axis_z002_entry = QLineEdit()
self.axis_z002_entry.editingFinished.connect(self.set_z002)
self.axis_z003_label = QLabel("Axis of Rotation (Dir 03):")
self.axis_z003_entry = QLineEdit()
self.axis_z003_entry.editingFinished.connect(self.set_z003)
self.axis_z004_label = QLabel("Axis of Rotation (Dir 04):")
self.axis_z004_entry = QLineEdit()
self.axis_z004_entry.editingFinished.connect(self.set_z004)
self.axis_z005_label = QLabel("Axis of Rotation (Dir 05):")
self.axis_z005_entry = QLineEdit()
self.axis_z005_entry.editingFinished.connect(self.set_z005)
self.axis_z006_label = QLabel("Axis of Rotation (Dir 06):")
self.axis_z006_entry = QLineEdit()
self.axis_z006_entry.editingFinished.connect(self.set_z006)
self.axis_z007_label = QLabel("Axis of Rotation (Dir 07):")
self.axis_z007_entry = QLineEdit()
self.axis_z007_entry.editingFinished.connect(self.set_z007)
self.axis_z008_label = QLabel("Axis of Rotation (Dir 08):")
self.axis_z008_entry = QLineEdit()
self.axis_z008_entry.editingFinished.connect(self.set_z008)
self.axis_z009_label = QLabel("Axis of Rotation (Dir 09):")
self.axis_z009_entry = QLineEdit()
self.axis_z009_entry.editingFinished.connect(self.set_z009)
self.axis_z010_label = QLabel("Axis of Rotation (Dir 10):")
self.axis_z010_entry = QLineEdit()
self.axis_z010_entry.editingFinished.connect(self.set_z010)
self.axis_z011_label = QLabel("Axis of Rotation (Dir 11):")
self.axis_z011_entry = QLineEdit()
self.axis_z011_entry.editingFinished.connect(self.set_z011)
self.stitch_button = QPushButton("Stitch")
self.stitch_button.clicked.connect(self.stitch_button_pressed)
self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold")
self.delete_button = QPushButton("Delete output dir")
self.delete_button.clicked.connect(self.delete_button_pressed)
self.help_button = QPushButton("Help")
self.help_button.clicked.connect(self.help_button_pressed)
self.import_parameters_button = QPushButton("Import Parameters from File")
self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed)
self.save_parameters_button = QPushButton("Save Parameters to File")
self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed)
self.set_layout()
def set_layout(self):
layout = QGridLayout()
layout.addWidget(self.input_dir_button, 0, 0, 1, 4)
layout.addWidget(self.input_dir_entry, 1, 0, 1, 4)
layout.addWidget(self.temp_dir_button, 2, 0, 1, 4)
layout.addWidget(self.temp_dir_entry, 3, 0, 1, 4)
layout.addWidget(self.output_dir_button, 4, 0, 1, 4)
layout.addWidget(self.output_dir_entry, 5, 0, 1, 4)
layout.addWidget(self.crop_checkbox, 6, 0, 1, 4)
layout.addWidget(self.axis_bottom_label, 7, 0)
layout.addWidget(self.axis_bottom_entry, 7, 1)
layout.addWidget(self.axis_top_label, 7, 2)
layout.addWidget(self.axis_top_entry, 7, 3)
self.axis_group.setCheckable(True)
self.axis_group.setChecked(False)
axis_layout = QGridLayout()
axis_layout.addWidget(self.axis_z000_label, 0, 0)
axis_layout.addWidget(self.axis_z000_entry, 0, 1)
axis_layout.addWidget(self.axis_z006_label, 0, 2)
axis_layout.addWidget(self.axis_z006_entry, 0, 3)
axis_layout.addWidget(self.axis_z001_label, 1, 0)
axis_layout.addWidget(self.axis_z001_entry, 1, 1)
axis_layout.addWidget(self.axis_z007_label, 1, 2)
axis_layout.addWidget(self.axis_z007_entry, 1, 3)
axis_layout.addWidget(self.axis_z002_label, 2, 0)
axis_layout.addWidget(self.axis_z002_entry, 2, 1)
axis_layout.addWidget(self.axis_z008_label, 2, 2)
axis_layout.addWidget(self.axis_z008_entry, 2, 3)
axis_layout.addWidget(self.axis_z003_label, 3, 0)
axis_layout.addWidget(self.axis_z003_entry, 3, 1)
axis_layout.addWidget(self.axis_z009_label, 3, 2)
axis_layout.addWidget(self.axis_z009_entry, 3, 3)
axis_layout.addWidget(self.axis_z004_label, 4, 0)
axis_layout.addWidget(self.axis_z004_entry, 4, 1)
axis_layout.addWidget(self.axis_z010_label, 4, 2)
axis_layout.addWidget(self.axis_z010_entry, 4, 3)
axis_layout.addWidget(self.axis_z005_label, 5, 0)
axis_layout.addWidget(self.axis_z005_entry, 5, 1)
axis_layout.addWidget(self.axis_z011_label, 5, 2)
axis_layout.addWidget(self.axis_z011_entry, 5, 3)
self.axis_group.setLayout(axis_layout)
self.axis_group.setTabOrder(self.axis_z000_entry, self.axis_z001_entry)
self.axis_group.setTabOrder(self.axis_z001_entry, self.axis_z002_entry)
self.axis_group.setTabOrder(self.axis_z002_entry, self.axis_z003_entry)
self.axis_group.setTabOrder(self.axis_z003_entry, self.axis_z004_entry)
self.axis_group.setTabOrder(self.axis_z004_entry, self.axis_z005_entry)
self.axis_group.setTabOrder(self.axis_z005_entry, self.axis_z006_entry)
self.axis_group.setTabOrder(self.axis_z006_entry, self.axis_z007_entry)
self.axis_group.setTabOrder(self.axis_z007_entry, self.axis_z008_entry)
self.axis_group.setTabOrder(self.axis_z008_entry, self.axis_z009_entry)
self.axis_group.setTabOrder(self.axis_z009_entry, self.axis_z010_entry)
self.axis_group.setTabOrder(self.axis_z010_entry, self.axis_z011_entry)
layout.addWidget(self.axis_group, 8, 0, 1, 4)
layout.addWidget(self.help_button, 9, 0)
layout.addWidget(self.delete_button, 9, 1)
layout.addWidget(self.stitch_button, 9, 2, 1, 2)
layout.addWidget(self.import_parameters_button, 10, 0, 1, 2)
layout.addWidget(self.save_parameters_button, 10, 2, 1, 2)
self.setLayout(layout)
def init_values(self):
self.parameters = {'parameters_type': '360_multi_stitch'}
self.parameters['360multi_input_dir'] = os.path.expanduser('~')#"~/"#os.getcwd()
self.input_dir_entry.setText(self.parameters['360multi_input_dir'])
self.parameters['360multi_temp_dir'] = os.path.join(
os.path.expanduser('~'), "tmp-batch360stitch")
self.temp_dir_entry.setText(self.parameters['360multi_temp_dir'])
self.parameters['360multi_output_dir'] = os.path.expanduser('~')
self.output_dir_entry.setText(self.parameters['360multi_output_dir'])
self.parameters['360multi_crop_projections'] = True
self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections'])
self.parameters['360multi_bottom_axis'] = 245
self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis']))
self.parameters['360multi_top_axis'] = 245
self.axis_top_entry.setText(str(self.parameters['360multi_top_axis']))
self.parameters['360multi_axis'] = self.parameters['360multi_bottom_axis']
self.parameters['360multi_manual_axis'] = False
self.parameters['360multi_axis_dict'] = dict.fromkeys(['z000', 'z001', 'z002', 'z003', 'z004', 'z005',
'z006', 'z007', 'z008', 'z009', 'z010', 'z011'], 200)
def update_parameters(self, new_parameters):
LOG.debug("Update parameters")
if new_parameters['parameters_type'] != '360_multi_stitch':
print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type']))
return -1
# Update parameters dictionary (which is passed to auto_stitch_funcs)
self.parameters = new_parameters
# Update displayed parameters for GUI
self.input_dir_entry.setText(self.parameters['360multi_input_dir'])
self.temp_dir_entry.setText(self.parameters['360multi_temp_dir'])
self.output_dir_entry.setText(self.parameters['360multi_output_dir'])
self.crop_checkbox.setChecked(self.parameters['360multi_crop_projections'])
self.axis_bottom_entry.setText(str(self.parameters['360multi_bottom_axis']))
self.axis_top_entry.setText(str(self.parameters['360multi_top_axis']))
self.axis_group.setChecked(bool(self.parameters['360multi_manual_axis']))
self.axis_z000_entry.setText(str(self.parameters['360multi_axis_dict']['z000']))
self.axis_z001_entry.setText(str(self.parameters['360multi_axis_dict']['z001']))
self.axis_z002_entry.setText(str(self.parameters['360multi_axis_dict']['z002']))
self.axis_z003_entry.setText(str(self.parameters['360multi_axis_dict']['z003']))
self.axis_z004_entry.setText(str(self.parameters['360multi_axis_dict']['z004']))
self.axis_z005_entry.setText(str(self.parameters['360multi_axis_dict']['z005']))
self.axis_z006_entry.setText(str(self.parameters['360multi_axis_dict']['z006']))
self.axis_z007_entry.setText(str(self.parameters['360multi_axis_dict']['z007']))
self.axis_z008_entry.setText(str(self.parameters['360multi_axis_dict']['z008']))
self.axis_z009_entry.setText(str(self.parameters['360multi_axis_dict']['z009']))
self.axis_z010_entry.setText(str(self.parameters['360multi_axis_dict']['z010']))
self.axis_z011_entry.setText(str(self.parameters['360multi_axis_dict']['z011']))
return 0
def input_button_pressed(self):
LOG.debug("Input button pressed")
dir_explore = QFileDialog(self)
self.parameters['360multi_input_dir'] = dir_explore.getExistingDirectory()
self.input_dir_entry.setText(self.parameters['360multi_input_dir'])
def set_input_entry(self):
LOG.debug("Input directory: " + str(self.input_dir_entry.text()))
self.parameters['360multi_input_dir'] = str(self.input_dir_entry.text())
def temp_button_pressed(self):
LOG.debug("Temp button pressed")
dir_explore = QFileDialog(self)
self.parameters['360multi_temp_dir'] = dir_explore.getExistingDirectory()
self.temp_dir_entry.setText(self.parameters['360multi_temp_dir'])
def set_temp_entry(self):
LOG.debug("Temp directory: " + str(self.temp_dir_entry.text()))
self.parameters['360multi_temp_dir'] = str(self.temp_dir_entry.text())
def output_button_pressed(self):
LOG.debug("Output button pressed")
dir_explore = QFileDialog(self)
self.parameters['360multi_output_dir'] = dir_explore.getExistingDirectory()
self.output_dir_entry.setText(self.parameters['360multi_output_dir'])
def set_output_entry(self):
LOG.debug("Output directory: " + str(self.output_dir_entry.text()))
self.parameters['360multi_output_dir'] = str(self.output_dir_entry.text())
def set_crop_projections_checkbox(self):
LOG.debug("Crop projections: " + str(self.crop_checkbox.isChecked()))
self.parameters['360multi_crop_projections'] = bool(self.crop_checkbox.isChecked())
def set_axis_bottom(self):
LOG.debug("Axis Bottom : " + str(self.axis_bottom_entry.text()))
self.parameters['360multi_bottom_axis'] = int(self.axis_bottom_entry.text())
def set_axis_top(self):
LOG.debug("Axis Top: " + str(self.axis_top_entry.text()))
self.parameters['360multi_top_axis'] = int(self.axis_top_entry.text())
def set_axis_group(self):
if self.axis_group.isChecked():
self.axis_bottom_label.setEnabled(False)
self.axis_bottom_entry.setEnabled(False)
self.axis_top_label.setEnabled(False)
self.axis_top_entry.setEnabled(False)
self.parameters['360multi_manual_axis'] = True
LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis']))
else:
self.axis_bottom_label.setEnabled(True)
self.axis_bottom_entry.setEnabled(True)
self.axis_top_label.setEnabled(True)
self.axis_top_entry.setEnabled(True)
self.parameters['360multi_manual_axis'] = False
LOG.debug("Enter axis of rotation manually: " + str(self.parameters['360multi_manual_axis']))
def set_z000(self):
LOG.debug("z000 axis: " + str(self.axis_z000_entry.text()))
self.parameters['360multi_axis_dict']['z000'] = int(self.axis_z000_entry.text())
def set_z001(self):
LOG.debug("z001 axis: " + str(self.axis_z001_entry.text()))
self.parameters['360multi_axis_dict']['z001'] = int(self.axis_z001_entry.text())
def set_z002(self):
LOG.debug("z002 axis: " + str(self.axis_z002_entry.text()))
self.parameters['360multi_axis_dict']['z002'] = int(self.axis_z002_entry.text())
def set_z003(self):
LOG.debug("z003 axis: " + str(self.axis_z003_entry.text()))
self.parameters['360multi_axis_dict']['z003'] = int(self.axis_z003_entry.text())
def set_z004(self):
LOG.debug("z004 axis: " + str(self.axis_z004_entry.text()))
self.parameters['360multi_axis_dict']['z004'] = int(self.axis_z004_entry.text())
def set_z005(self):
LOG.debug("z005 axis: " + str(self.axis_z005_entry.text()))
self.parameters['360multi_axis_dict']['z005'] = int(self.axis_z005_entry.text())
def set_z006(self):
LOG.debug("z006 axis: " + str(self.axis_z006_entry.text()))
self.parameters['360multi_axis_dict']['z006'] = int(self.axis_z006_entry.text())
def set_z007(self):
LOG.debug("z007 axis: " + str(self.axis_z007_entry.text()))
self.parameters['360multi_axis_dict']['z007'] = int(self.axis_z007_entry.text())
def set_z008(self):
LOG.debug("z008 axis: " + str(self.axis_z008_entry.text()))
self.parameters['360multi_axis_dict']['z008'] = int(self.axis_z008_entry.text())
def set_z009(self):
LOG.debug("z009 axis: " + str(self.axis_z009_entry.text()))
self.parameters['360multi_axis_dict']['z009'] = int(self.axis_z009_entry.text())
def set_z010(self):
LOG.debug("z010 axis: " + str(self.axis_z010_entry.text()))
self.parameters['360multi_axis_dict']['z010'] = int(self.axis_z010_entry.text())
def set_z011(self):
LOG.debug("z011 axis: " + str(self.axis_z011_entry.text()))
self.parameters['360multi_axis_dict']['z011'] = int(self.axis_z011_entry.text())
def stitch_button_pressed(self):
LOG.debug("Stitch button pressed")
self.get_fdt_names_on_stitch_pressed.emit()
if os.path.exists(self.parameters['360multi_temp_dir']):
qm = QMessageBox()
rep = qm.question(self, '', "Temporary dir is not empty. Is it safe to delete it?", qm.Yes | qm.No)
if rep == qm.Yes:
try:
rmtree(self.parameters['360multi_temp_dir'])
except:
warning_message("Problems with deleting directory")
else:
return
if os.path.exists(self.parameters['360multi_output_dir']):
warning_message('Output directory exists. Delete it or select another one.')
return
# if os.path.exists(self.parameters['360multi_output_dir']):
# # raise ValueError('Output directory exists')
# qm = QMessageBox()
# rep = qm.question(self, '', "Output dir is not empty. Can I delete it?", qm.Yes | qm.No)
# if rep == qm.Yes:
# os.system('rm -r {}'.format(self.parameters['360multi_output_dir']))
# else:
# return
print("======= Begin 360 Multi-Stitch =======")
main_360_mp_depth2(self.parameters)
if os.path.isdir(self.parameters['360multi_output_dir']):
params_file_path = os.path.join(self.parameters['360multi_output_dir'], '360_multi_stitch_params.yaml')
params.save_parameters(self.parameters, params_file_path)
print("==== Waiting for Next Task ====")
def delete_button_pressed(self):
print("---- Deleting Data From Output Directory ----")
LOG.debug("Delete button pressed")
qm = QMessageBox()
rep = qm.question(self, '', "Is it safe to delete the directory?", qm.Yes | qm.No)
if rep == qm.Yes:
try:
rmtree(self.parameters['360multi_output_dir'])
except:
warning_message("Problems with deleting directory")
else:
return
def help_button_pressed(self):
LOG.debug("Help button pressed")
h = "Stitches images horizontally\n"
h += "Directory structure is, f.i., Input/000, Input/001,...Input/00N\n"
h += "Each 000, 001, ... 00N directory must have identical subdirectory \"Type\"\n"
h += "Selected range of images from \"Type\" directory will be stitched vertically\n"
h += "across all subdirectories in the Input directory"
h += "to be added as options:\n"
h += "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching"
QMessageBox.information(self, "Help", h)
def import_parameters_button_pressed(self):
LOG.debug("Import params button clicked")
dir_explore = QFileDialog(self)
params_file_path = dir_explore.getOpenFileName(filter="*.yaml")
try:
file_in = open(params_file_path[0], 'r')
new_parameters = yaml.load(file_in, Loader=yaml.FullLoader)
if self.update_parameters(new_parameters) == 0:
print("Parameters file loaded from: " + str(params_file_path[0]))
except FileNotFoundError:
print("You need to select a valid input file")
def save_parameters_button_pressed(self):
    """Ask for a file name and dump ``self.parameters`` to it as YAML.

    ".yaml" is appended when the chosen name has no extension. Cancelling the
    dialog writes nothing.
    """
    LOG.debug("Save params button clicked")
    dir_explore = QFileDialog(self)
    params_file_path = dir_explore.getSaveFileName(filter="*.yaml")
    file_path = params_file_path[0]
    if not file_path:
        # Dialog cancelled: without this guard a file named ".yaml" would be
        # written into the current working directory.
        LOG.debug("Save params dialog cancelled")
        return
    # If the user doesn't enter the .yaml extension then append it to filepath
    if os.path.splitext(os.path.basename(file_path))[-1] == "":
        file_path += ".yaml"
    try:
        with open(file_path, 'w') as file_out:
            yaml.dump(self.parameters, file_out)
    except OSError:
        print("You need to select a directory and use a valid file name")
        return
    print("Parameters file saved at: " + str(file_path))
tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ez_360_overlap_qt.py 0000664 0000000 0000000 00000030030 14237137211 0023703 0 ustar 00root root 0000000 0000000 from PyQt5.QtWidgets import (
QGroupBox,
QPushButton,
QCheckBox,
QLabel,
QLineEdit,
QGridLayout,
QFileDialog,
QMessageBox,
)
from PyQt5.QtCore import pyqtSignal
import logging
from shutil import rmtree
import yaml
import os
from tofu.ez.Helpers.find_360_overlap import find_overlap
import tofu.ez.params as params
import getpass
#TODO Make all stitching tools compatible with the bigtiffs
LOG = logging.getLogger(__name__)
class Overlap360Group(QGroupBox):
    """Widget group for the half-acquisition (360 degree) overlap / axis search.

    The user chooses an input, a temporary and an output directory, a sinogram
    row and an integer search range. All settings live in ``self.parameters``
    (created by ``init_values``), which is passed unchanged to ``find_overlap``.
    """

    # NOTE(review): not emitted or connected anywhere in this class.
    get_fdt_names_on_stitch_pressed = pyqtSignal()

    def __init__(self):
        super().__init__()

        self.setTitle("Reconstruct one slice with different axis of rotation positions for half-acqusition mode data set(s)")
        self.setStyleSheet('QGroupBox {color: Orange;}')

        # Directory selectors: a dialog button plus an editable text entry each
        self.input_dir_button = QPushButton("Select input directory")
        self.input_dir_button.clicked.connect(self.input_button_pressed)
        self.input_dir_entry = QLineEdit()
        self.input_dir_entry.editingFinished.connect(self.set_input_entry)

        self.temp_dir_button = QPushButton("Select temp directory")
        self.temp_dir_button.clicked.connect(self.temp_button_pressed)
        self.temp_dir_entry = QLineEdit()
        self.temp_dir_entry.editingFinished.connect(self.set_temp_entry)

        self.output_dir_button = QPushButton("Select output directory")
        self.output_dir_button.clicked.connect(self.output_button_pressed)
        self.output_dir_entry = QLineEdit()
        self.output_dir_entry.editingFinished.connect(self.set_output_entry)

        # Integer search settings
        self.pixel_row_label = QLabel("Pixel row to be used for sinogram")
        self.pixel_row_entry = QLineEdit()
        self.pixel_row_entry.editingFinished.connect(self.set_pixel_row)

        self.min_label = QLabel("Lower limit of stitch/axis search range")
        self.min_entry = QLineEdit()
        self.min_entry.editingFinished.connect(self.set_lower_limit)

        self.max_label = QLabel("Upper limit of stitch/axis search range")
        self.max_entry = QLineEdit()
        self.max_entry.editingFinished.connect(self.set_upper_limit)

        self.step_label = QLabel("Value by which to increment through search range")
        self.step_entry = QLineEdit()
        self.step_entry.editingFinished.connect(self.set_increment)

        # NOTE(review): the label says "Apply ring removal" but the widget is named
        # axis_on_left and drives '360overlap_axis_on_left'; it is also always
        # disabled. Confirm which meaning is intended.
        self.axis_on_left = QCheckBox("Apply ring removal")
        self.axis_on_left.setEnabled(False)
        self.axis_on_left.stateChanged.connect(self.set_axis_checkbox)

        self.help_button = QPushButton("Help")
        self.help_button.clicked.connect(self.help_button_pressed)

        self.find_overlap_button = QPushButton("Generate slices")
        self.find_overlap_button.clicked.connect(self.overlap_button_pressed)
        self.find_overlap_button.setStyleSheet("color:royalblue;font-weight:bold")

        self.import_parameters_button = QPushButton("Import Parameters from File")
        self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed)

        self.save_parameters_button = QPushButton("Save Parameters to File")
        self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed)

        self.set_layout()

    def set_layout(self):
        """Place all widgets in a single grid (directories on top, buttons at the bottom)."""
        layout = QGridLayout()
        layout.addWidget(self.input_dir_button, 0, 0, 1, 2)
        layout.addWidget(self.input_dir_entry, 1, 0, 1, 2)
        layout.addWidget(self.temp_dir_button, 2, 0, 1, 2)
        layout.addWidget(self.temp_dir_entry, 3, 0, 1, 2)
        layout.addWidget(self.output_dir_button, 4, 0, 1, 2)
        layout.addWidget(self.output_dir_entry, 5, 0, 1, 2)
        layout.addWidget(self.pixel_row_label, 6, 0)
        layout.addWidget(self.pixel_row_entry, 6, 1)
        layout.addWidget(self.min_label, 7, 0)
        layout.addWidget(self.min_entry, 7, 1)
        layout.addWidget(self.max_label, 8, 0)
        layout.addWidget(self.max_entry, 8, 1)
        layout.addWidget(self.step_label, 9, 0)
        layout.addWidget(self.step_entry, 9, 1)
        layout.addWidget(self.axis_on_left, 10, 0)
        layout.addWidget(self.help_button, 11, 0)
        layout.addWidget(self.find_overlap_button, 11, 1)
        layout.addWidget(self.import_parameters_button, 12, 0)
        layout.addWidget(self.save_parameters_button, 12, 1)
        self.setLayout(layout)

    def init_values(self):
        """Create ``self.parameters`` with default values and show them in the widgets."""
        self.parameters = {'parameters_type': '360_overlap'}
        self.parameters['360overlap_input_dir'] = os.path.expanduser('~')
        self.input_dir_entry.setText(self.parameters['360overlap_input_dir'])
        self.parameters['360overlap_temp_dir'] = os.path.join(
            os.path.expanduser('~'), "tmp-360axis-search")
        self.temp_dir_entry.setText(self.parameters['360overlap_temp_dir'])
        self.parameters['360overlap_output_dir'] = os.path.join(
            os.path.expanduser('~'), "ezufo-360axis-search")
        self.output_dir_entry.setText(self.parameters['360overlap_output_dir'])
        self.parameters['360overlap_row'] = 200
        self.pixel_row_entry.setText(str(self.parameters['360overlap_row']))
        self.parameters['360overlap_lower_limit'] = 100
        self.min_entry.setText(str(self.parameters['360overlap_lower_limit']))
        self.parameters['360overlap_upper_limit'] = 200
        self.max_entry.setText(str(self.parameters['360overlap_upper_limit']))
        self.parameters['360overlap_increment'] = 1
        self.step_entry.setText(str(self.parameters['360overlap_increment']))
        self.parameters['360overlap_axis_on_left'] = True
        self.axis_on_left.setChecked(bool(self.parameters['360overlap_axis_on_left']))

    def update_parameters(self, new_parameters):
        """Replace ``self.parameters`` with ``new_parameters`` and refresh the widgets.

        Returns -1 (leaving everything unchanged) if ``new_parameters`` is not a
        '360_overlap' parameter set, and 0 on success.
        """
        LOG.debug("Update parameters")
        if new_parameters['parameters_type'] != '360_overlap':
            print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type']))
            return -1
        # Update parameters dictionary (which is passed to find_overlap)
        self.parameters = new_parameters
        # Update displayed parameters for GUI; str() guards against non-string YAML values
        self.input_dir_entry.setText(str(self.parameters['360overlap_input_dir']))
        self.temp_dir_entry.setText(str(self.parameters['360overlap_temp_dir']))
        self.output_dir_entry.setText(str(self.parameters['360overlap_output_dir']))
        self.pixel_row_entry.setText(str(self.parameters['360overlap_row']))
        self.min_entry.setText(str(self.parameters['360overlap_lower_limit']))
        self.max_entry.setText(str(self.parameters['360overlap_upper_limit']))
        self.step_entry.setText(str(self.parameters['360overlap_increment']))
        self.axis_on_left.setChecked(bool(self.parameters['360overlap_axis_on_left']))
        # The import handler checks this value to decide whether loading succeeded
        return 0

    def _choose_directory(self, key, entry):
        """Ask for a directory and store it under ``key``; keep the old value on cancel."""
        dir_explore = QFileDialog(self)
        directory = dir_explore.getExistingDirectory()
        # getExistingDirectory() returns "" when the dialog is cancelled
        if directory:
            self.parameters[key] = directory
            entry.setText(directory)

    def _set_int_parameter(self, entry, key, description):
        """Parse ``entry`` as an int into ``self.parameters[key]``.

        Invalid text is rejected and the previous value is put back into the entry.
        Raising here would escape the Qt slot, which aborts the application
        under PyQt5 >= 5.5.
        """
        text = entry.text()
        LOG.debug(description + ": " + str(text))
        try:
            self.parameters[key] = int(text)
        except ValueError:
            LOG.warning("Ignoring non-integer value for %s: %r", description, text)
            entry.setText(str(self.parameters.get(key, "")))

    def input_button_pressed(self):
        LOG.debug("Select input button pressed")
        self._choose_directory('360overlap_input_dir', self.input_dir_entry)

    def set_input_entry(self):
        LOG.debug("Input: " + str(self.input_dir_entry.text()))
        self.parameters['360overlap_input_dir'] = str(self.input_dir_entry.text())

    def temp_button_pressed(self):
        LOG.debug("Select temp button pressed")
        self._choose_directory('360overlap_temp_dir', self.temp_dir_entry)

    def set_temp_entry(self):
        LOG.debug("Temp: " + str(self.temp_dir_entry.text()))
        self.parameters['360overlap_temp_dir'] = str(self.temp_dir_entry.text())

    def output_button_pressed(self):
        LOG.debug("Select output button pressed")
        self._choose_directory('360overlap_output_dir', self.output_dir_entry)

    def set_output_entry(self):
        LOG.debug("Output: " + str(self.output_dir_entry.text()))
        self.parameters['360overlap_output_dir'] = str(self.output_dir_entry.text())

    def set_pixel_row(self):
        self._set_int_parameter(self.pixel_row_entry, '360overlap_row', "Pixel row")

    def set_lower_limit(self):
        self._set_int_parameter(self.min_entry, '360overlap_lower_limit', "Lower limit")

    def set_upper_limit(self):
        self._set_int_parameter(self.max_entry, '360overlap_upper_limit', "Upper limit")

    def set_increment(self):
        self._set_int_parameter(self.step_entry, '360overlap_increment', "Value of increment")

    def set_axis_checkbox(self):
        LOG.debug("Is rotation axis on left-hand-side?: " + str(self.axis_on_left.isChecked()))
        self.parameters['360overlap_axis_on_left'] = bool(self.axis_on_left.isChecked())

    def overlap_button_pressed(self):
        """Recreate the temp/output directories (after confirmation) and run ``find_overlap``.

        When done, the parameters used are saved as '360_overlap_params.yaml' in
        the output directory.
        """
        LOG.debug("Find overlap button pressed")
        output_dir = self.parameters['360overlap_output_dir']
        temp_dir = self.parameters['360overlap_temp_dir']
        if os.path.exists(output_dir) or os.path.exists(temp_dir):
            qm = QMessageBox()
            rep = qm.question(self, '', "Output directory or(and) temporary dir exist. Can I delete both?", qm.Yes | qm.No)
            if rep != qm.Yes:
                return
            # Remove each directory independently. Previously a missing output dir
            # made rmtree raise before the temp dir was removed, and the later
            # os.makedirs(temp_dir) then failed with FileExistsError.
            for path in (output_dir, temp_dir):
                if not os.path.exists(path):
                    continue
                try:
                    rmtree(path)
                except OSError as err:
                    LOG.exception("Could not delete %s", path)
                    QMessageBox.warning(self, "Warning", "Could not delete " + path + ": " + str(err))
                    return
        os.makedirs(temp_dir, exist_ok=True)
        os.makedirs(output_dir, exist_ok=True)
        find_overlap(self.parameters)
        if os.path.exists(self.parameters['360overlap_output_dir']):
            params_file_path = os.path.join(self.parameters['360overlap_output_dir'], '360_overlap_params.yaml')
            params.save_parameters(self.parameters, params_file_path)

    def help_button_pressed(self):
        """Explain the overlap search in a message box."""
        LOG.debug("Help button pressed")
        h = "This script takes as input a CT scan that has been collected in 'half-acquisition' mode"
        h += " and produces a series of reconstructed slices, each of which are generated by cropping and"
        h += " concatenating opposing projections together over a range of 'overlap' values (i.e. the pixel column"
        h += " at which the images are cropped and concatenated)."
        h += " The objective is to review this series of images to determine the pixel column at which the axis of rotation"
        h += " is located (much like the axis search function commonly used in reconstruction software)."
        QMessageBox.information(self, "Help", h)

    def import_parameters_button_pressed(self):
        """Load a YAML parameters file chosen by the user; report problems on stdout."""
        LOG.debug("Import params button clicked")
        dir_explore = QFileDialog(self)
        params_file_path = dir_explore.getOpenFileName(filter="*.yaml")
        try:
            # NOTE(review): yaml.safe_load would suffice for plain parameter dicts
            with open(params_file_path[0], 'r') as file_in:
                new_parameters = yaml.load(file_in, Loader=yaml.FullLoader)
        except OSError:
            # Also covers a cancelled dialog (empty path)
            print("You need to select a valid input file")
            return
        except yaml.YAMLError as err:
            print("Error: Could not parse parameters file: " + str(err))
            return
        if not isinstance(new_parameters, dict) or 'parameters_type' not in new_parameters:
            print("Error: Invalid parameter file: " + str(params_file_path[0]))
            return
        if self.update_parameters(new_parameters) != -1:
            print("Parameters file loaded from: " + str(params_file_path[0]))

    def save_parameters_button_pressed(self):
        """Dump ``self.parameters`` as YAML to a user-chosen file (".yaml" appended if missing)."""
        LOG.debug("Save params button clicked")
        dir_explore = QFileDialog(self)
        params_file_path = dir_explore.getSaveFileName(filter="*.yaml")
        file_path = params_file_path[0]
        if not file_path:
            # Dialog cancelled: do not write a stray ".yaml" into the working directory
            LOG.debug("Save params dialog cancelled")
            return
        # If the user doesn't enter the .yaml extension then append it to filepath
        if os.path.splitext(os.path.basename(file_path))[-1] == "":
            file_path += ".yaml"
        try:
            with open(file_path, 'w') as file_out:
                yaml.dump(self.parameters, file_out)
        except OSError:
            print("You need to select a directory and use a valid file name")
            return
        print("Parameters file saved at: " + str(file_path))
tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ezmview_qt.py 0000664 0000000 0000000 00000024336 14237137211 0022647 0 ustar 00root root 0000000 0000000 import os
import logging
from PyQt5.QtWidgets import (
QGroupBox,
QPushButton,
QLineEdit,
QLabel,
QCheckBox,
QGridLayout,
QFileDialog,
QMessageBox,
)
import yaml
from tofu.ez.Helpers.mview_main import main_prep
LOG = logging.getLogger(__name__)
class EZMViewGroup(QGroupBox):
    """Widget group that distributes one flat sequence of tif frames over CT sub-directories.

    Settings are kept in ``self.parameters`` (created by ``init_values``) and are
    handed to ``main_prep`` when "Convert" is pressed.
    """

    def __init__(self):
        super().__init__()

        # NOTE(review): self.args and the e_* attributes are not read by any method
        # of this class; they may be used elsewhere - confirm before removing.
        self.args = {}
        self.e_indir = ""
        self.e_nproj = 0
        self.e_nflats = 0
        self.e_ndarks = 0
        self.e_nviews = 0
        self.e_noflats2 = False
        self.e_Andor = False

        self.setTitle("Split a sequence of tif files over flats/darks/tomo directories")
        self.setStyleSheet("QGroupBox {color: green;}")

        self.input_dir_button = QPushButton()
        self.input_dir_button.setText("Select directory with a CT sequence")
        self.input_dir_button.clicked.connect(self.select_directory)

        self.input_dir_entry = QLineEdit()
        self.input_dir_entry.editingFinished.connect(self.set_directory_entry)

        self.num_projections_label = QLabel()
        self.num_projections_label.setText("Number of projections")
        self.num_projections_entry = QLineEdit()
        self.num_projections_entry.editingFinished.connect(self.set_num_projections)

        self.num_flats_label = QLabel()
        self.num_flats_label.setText("Number of flats")
        self.num_flats_entry = QLineEdit()
        self.num_flats_entry.editingFinished.connect(self.set_num_flats)

        self.num_darks_label = QLabel()
        self.num_darks_label.setText("Number of darks")
        self.num_darks_entry = QLineEdit()
        self.num_darks_entry.editingFinished.connect(self.set_num_darks)

        self.num_vert_steps_label = QLabel()
        self.num_vert_steps_label.setText("Number of CT sets in the sequence")
        self.num_vert_steps_entry = QLineEdit()
        self.num_vert_steps_entry.editingFinished.connect(self.set_num_steps)

        self.no_trailing_flats_darks_checkbox = QCheckBox()
        self.no_trailing_flats_darks_checkbox.setText("No trailing flats/darks")
        self.no_trailing_flats_darks_checkbox.stateChanged.connect(self.set_trailing_checkbox)

        self.filenames_without_padding_checkbox = QCheckBox()
        self.filenames_without_padding_checkbox.setText("File names without zero padding")
        self.filenames_without_padding_checkbox.stateChanged.connect(self.set_file_names_checkbox)

        self.help_button = QPushButton()
        self.help_button.setText("Help")
        self.help_button.clicked.connect(self.help_button_pressed)

        self.undo_button = QPushButton()
        self.undo_button.setText("Undo")
        self.undo_button.clicked.connect(self.undo_button_pressed)

        self.convert_button = QPushButton()
        self.convert_button.setText("Convert")
        self.convert_button.clicked.connect(self.convert_button_pressed)
        self.convert_button.setStyleSheet("color:royalblue;font-weight:bold")

        self.save_parameters_button = QPushButton("Save Parameters to File")
        self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed)

        self.import_parameters_button = QPushButton("Import Parameters from File")
        self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed)

        self.set_layout()

    def set_layout(self):
        """Place all widgets in a three-column grid."""
        layout = QGridLayout()
        layout.addWidget(self.input_dir_button, 0, 0, 1, 3)
        layout.addWidget(self.input_dir_entry, 1, 0, 1, 3)
        layout.addWidget(self.num_projections_label, 2, 0)
        layout.addWidget(self.num_projections_entry, 2, 1, 1, 2)
        layout.addWidget(self.num_flats_label, 3, 0)
        layout.addWidget(self.num_flats_entry, 3, 1, 1, 2)
        layout.addWidget(self.num_darks_label, 4, 0)
        layout.addWidget(self.num_darks_entry, 4, 1, 1, 2)
        layout.addWidget(self.num_vert_steps_label, 5, 0)
        layout.addWidget(self.num_vert_steps_entry, 5, 1, 1, 2)
        layout.addWidget(self.no_trailing_flats_darks_checkbox, 6, 0)
        layout.addWidget(self.filenames_without_padding_checkbox, 6, 1, 1, 2)
        layout.addWidget(self.help_button, 7, 0, 1, 1)
        layout.addWidget(self.undo_button, 7, 1, 1, 1)
        layout.addWidget(self.convert_button, 7, 2, 1, 1)
        layout.addWidget(self.import_parameters_button, 8, 0, 1, 2)
        layout.addWidget(self.save_parameters_button, 8, 2, 1, 1)
        self.setLayout(layout)

    def init_values(self):
        """Create ``self.parameters`` with defaults (input dir = CWD) and show them."""
        self.parameters = {'parameters_type': 'ez_mview'}
        self.input_dir_entry.setText(os.getcwd())
        self.parameters['ezmview_input_dir'] = os.getcwd()
        self.num_projections_entry.setText("3000")
        self.parameters['ezmview_num_projections'] = 3000
        self.num_flats_entry.setText("10")
        self.parameters['ezmview_num_flats'] = 10
        self.num_darks_entry.setText("10")
        self.parameters['ezmview_num_darks'] = 10
        self.num_vert_steps_entry.setText("1")
        self.parameters['ezmview_num_vertical_steps'] = 1
        self.no_trailing_flats_darks_checkbox.setChecked(False)
        self.parameters['ezmview_flats2'] = False
        self.filenames_without_padding_checkbox.setChecked(False)
        # True when "File names without zero padding" is checked
        self.parameters['ezmview_zero_padding'] = False

    def update_parameters(self, new_parameters):
        """Replace ``self.parameters`` with ``new_parameters`` and refresh the widgets.

        Returns -1 (leaving everything unchanged) if ``new_parameters`` is not an
        'ez_mview' parameter set, and 0 on success.
        """
        LOG.debug("Update parameters")
        if new_parameters['parameters_type'] != 'ez_mview':
            print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type']))
            return -1
        # Update parameters dictionary (which is passed to main_prep)
        self.parameters = new_parameters
        # Update displayed parameters for GUI
        self.input_dir_entry.setText(str(self.parameters['ezmview_input_dir']))
        self.num_projections_entry.setText(str(self.parameters['ezmview_num_projections']))
        self.num_flats_entry.setText(str(self.parameters['ezmview_num_flats']))
        self.num_darks_entry.setText(str(self.parameters['ezmview_num_darks']))
        self.num_vert_steps_entry.setText(str(self.parameters['ezmview_num_vertical_steps']))
        self.no_trailing_flats_darks_checkbox.setChecked(bool(self.parameters['ezmview_flats2']))
        self.filenames_without_padding_checkbox.setChecked(bool(self.parameters['ezmview_zero_padding']))
        # The import handler checks this value to decide whether loading succeeded
        return 0

    def _set_int_parameter(self, entry, key, description):
        """Parse ``entry`` as an int into ``self.parameters[key]``.

        Invalid text is rejected and the previous value is put back into the entry.
        Raising here would escape the Qt slot, which aborts the application
        under PyQt5 >= 5.5.
        """
        text = entry.text()
        LOG.debug(description + ": " + str(text))
        try:
            self.parameters[key] = int(text)
        except ValueError:
            LOG.warning("Ignoring non-integer value for %s: %r", description, text)
            entry.setText(str(self.parameters.get(key, "")))

    def select_directory(self):
        """Ask for the input directory; keep the previous value if the dialog is cancelled."""
        LOG.debug("Select directory button pressed")
        dir_explore = QFileDialog(self)
        directory = dir_explore.getExistingDirectory()
        # getExistingDirectory() returns "" when the dialog is cancelled
        if not directory:
            return
        self.input_dir_entry.setText(directory)
        self.parameters['ezmview_input_dir'] = directory

    def set_directory_entry(self):
        LOG.debug("Directory entry: " + str(self.input_dir_entry.text()))
        self.parameters['ezmview_input_dir'] = str(self.input_dir_entry.text())

    def set_num_projections(self):
        self._set_int_parameter(self.num_projections_entry, 'ezmview_num_projections', "Num projections")

    def set_num_flats(self):
        self._set_int_parameter(self.num_flats_entry, 'ezmview_num_flats', "Num flats")

    def set_num_darks(self):
        self._set_int_parameter(self.num_darks_entry, 'ezmview_num_darks', "Num darks")

    def set_num_steps(self):
        self._set_int_parameter(self.num_vert_steps_entry, 'ezmview_num_vertical_steps', "Num steps")

    def set_trailing_checkbox(self):
        LOG.debug("No trailing: " + str(self.no_trailing_flats_darks_checkbox.isChecked()))
        self.parameters['ezmview_flats2'] = bool(self.no_trailing_flats_darks_checkbox.isChecked())

    def set_file_names_checkbox(self):
        LOG.debug("File names without zero padding: " +
                  str(self.filenames_without_padding_checkbox.isChecked()))
        self.parameters['ezmview_zero_padding'] = \
            bool(self.filenames_without_padding_checkbox.isChecked())

    def convert_button_pressed(self):
        """Run ``main_prep`` with the current parameters."""
        LOG.debug("Convert button pressed")
        LOG.debug(self.parameters)
        main_prep(self.parameters)

    def undo_button_pressed(self):
        """Move every *.tif found below the input directory back into the input directory itself."""
        import subprocess
        LOG.debug("Undo button pressed")
        input_dir = str(self.parameters['ezmview_input_dir'])
        # Pass an argument list instead of an os.system() shell string so that
        # paths containing spaces or shell metacharacters are used verbatim.
        cmd = ["find", input_dir, "-type", "f", "-name", "*.tif",
               "-exec", "mv", "-t", input_dir, "{}", "+"]
        try:
            subprocess.run(cmd, check=False)
        except OSError:
            LOG.exception("Could not run 'find' to undo the conversion")

    def help_button_pressed(self):
        """Explain the conversion in a message box."""
        LOG.debug("Help button pressed")
        h = "Distributes a sequence of CT frames in flats/darks/tomo/flats2 directories\n"
        h += "assuming that acqusition sequence is flats->darks->tomo->flats2\n"
        h += 'Use only for sequences with flat fields acquired at 0 and 180!\n'
        h += "Conversions happens in-place but can be undone"
        QMessageBox.information(self, "Help", h)

    def import_parameters_button_pressed(self):
        """Load a YAML parameters file chosen by the user; report problems on stdout."""
        LOG.debug("Import params button clicked")
        dir_explore = QFileDialog(self)
        params_file_path = dir_explore.getOpenFileName(filter="*.yaml")
        try:
            # NOTE(review): yaml.safe_load would suffice for plain parameter dicts
            with open(params_file_path[0], 'r') as file_in:
                new_parameters = yaml.load(file_in, Loader=yaml.FullLoader)
        except OSError:
            # Also covers a cancelled dialog (empty path)
            print("You need to select a valid input file")
            return
        except yaml.YAMLError as err:
            print("Error: Could not parse parameters file: " + str(err))
            return
        if not isinstance(new_parameters, dict) or 'parameters_type' not in new_parameters:
            print("Error: Invalid parameter file: " + str(params_file_path[0]))
            return
        if self.update_parameters(new_parameters) != -1:
            print("Parameters file loaded from: " + str(params_file_path[0]))

    def save_parameters_button_pressed(self):
        """Dump ``self.parameters`` as YAML to a user-chosen file (".yaml" appended if missing)."""
        LOG.debug("Save params button clicked")
        dir_explore = QFileDialog(self)
        params_file_path = dir_explore.getSaveFileName(filter="*.yaml")
        file_path = params_file_path[0]
        if not file_path:
            # Dialog cancelled: do not write a stray ".yaml" into the working directory
            LOG.debug("Save params dialog cancelled")
            return
        # If the user doesn't enter the .yaml extension then append it to filepath
        if os.path.splitext(os.path.basename(file_path))[-1] == "":
            file_path += ".yaml"
        try:
            with open(file_path, 'w') as file_out:
                yaml.dump(self.parameters, file_out)
        except OSError:
            print("You need to select a directory and use a valid file name")
            return
        print("Parameters file saved at: " + str(file_path))
tofu-0.12.0/tofu/ez/GUI/Stitch_tools_tab/ezstitch_qt.py 0000664 0000000 0000000 00000052722 14237137211 0023016 0 ustar 00root root 0000000 0000000 import os
from PyQt5.QtWidgets import (
QGroupBox,
QPushButton,
QCheckBox,
QLabel,
QLineEdit,
QGridLayout,
QVBoxLayout,
QHBoxLayout,
QRadioButton,
QFileDialog,
QMessageBox,
)
from shutil import rmtree
import logging
import getpass
import yaml
import tofu.ez.params as params
from tofu.ez.Helpers.stitch_funcs import main_sti_mp, main_conc_mp, main_360_mp_depth1
from tofu.ez.GUI.message_dialog import warning_message
LOG = logging.getLogger(__name__)
class EZStitchGroup(QGroupBox):
def __init__(self):
    """Build all widgets of the vertical stitching / reslicing tool and wire their signals.

    Parameter values are not set here; ``self.parameters`` is created later by
    ``init_values``.
    """
    super().__init__()

    self.setTitle("Vertical stitching and reslicing tool")
    self.setStyleSheet('QGroupBox {color: purple;}')

    # --- Directory selectors: dialog button + editable entry for each ---
    self.input_dir_button = QPushButton()
    self.input_dir_button.setText("Select input directory")
    self.input_dir_button.setToolTip("Normally contains a bunch of directories at the first depth level\n" \
        "each of which has a subdirectory with the same name (second depth level). \n"
        "Images in these second-level subdirectories will be stitched together.")
    self.input_dir_button.clicked.connect(self.input_button_pressed)

    self.input_dir_entry = QLineEdit()
    self.input_dir_entry.editingFinished.connect(self.set_input_entry)

    self.tmp_dir_button = QPushButton()
    self.tmp_dir_button.setText("Select temporary directory")
    self.tmp_dir_button.clicked.connect(self.temp_button_pressed)

    self.tmp_dir_entry = QLineEdit()
    self.tmp_dir_entry.editingFinished.connect(self.set_temp_entry)

    self.output_dir_button = QPushButton()
    self.output_dir_button.setText("Directory to save stitched images")
    self.output_dir_button.clicked.connect(self.output_button_pressed)

    self.output_dir_entry = QLineEdit()
    self.output_dir_entry.editingFinished.connect(self.set_output_entry)

    # --- Which images to stitch ---
    self.types_of_images_label = QLabel()
    tmpstr = "Name of subdirectories which contain the same type of images in every directory in the input"
    self.types_of_images_label.setToolTip(tmpstr)
    self.types_of_images_label.setText("Name of subdirectories with the same type of images to stitch(e.g. sli, tomo, proj-pr, etc.)")
    self.types_of_images_entry = QLineEdit()
    self.types_of_images_entry.setToolTip(tmpstr)
    self.types_of_images_entry.editingFinished.connect(self.set_type_images)

    self.orthogonal_checkbox = QCheckBox()
    self.orthogonal_checkbox.setText("Stitch orthogonal sections (will reslice images in every subdirectory and then stitch)")
    self.orthogonal_checkbox.stateChanged.connect(self.set_stitch_checkbox)

    self.start_stop_step_label = QLabel()
    self.start_stop_step_label.setText("Which images to be stitched: start,stop,step:")
    self.start_stop_step_entry = QLineEdit()
    self.start_stop_step_entry.editingFinished.connect(self.set_start_stop_step)

    self.sample_moved_down_checkbox = QCheckBox()
    self.sample_moved_down_checkbox.setText("Sample was moved downwards during scan")
    self.sample_moved_down_checkbox.stateChanged.connect(self.set_sample_moved_down)

    # --- Mode 0: interpolate overlapping regions (all three mode radio buttons
    # call set_rButton, which records the selected mode) ---
    self.interpolate_regions_rButton = QRadioButton()
    self.interpolate_regions_rButton.setText("Interpolate overlapping regions and equalize intensity")
    self.interpolate_regions_rButton.clicked.connect(self.set_rButton)

    self.num_overlaps_label = QLabel()
    self.num_overlaps_label.setText("Number of overlapping rows")
    self.num_overlaps_entry = QLineEdit()
    self.num_overlaps_entry.editingFinished.connect(self.set_overlap)

    self.clip_histogram_checkbox = QCheckBox()
    self.clip_histogram_checkbox.setText("Clip histogram and convert slices to 8-bit before saving")
    self.clip_histogram_checkbox.stateChanged.connect(self.set_histogram_checkbox)

    self.min_value_label = QLabel()
    self.min_value_label.setText("Min value in 32-bit histogram")
    self.min_value_entry = QLineEdit()
    self.min_value_entry.editingFinished.connect(self.set_min_value)

    self.max_value_label = QLabel()
    self.max_value_label.setText("Max value in 32-bit histogram")
    self.max_value_entry = QLineEdit()
    self.max_value_entry.editingFinished.connect(self.set_max_value)

    # --- Mode 1: concatenate a row range only ---
    self.concatenate_rButton = QRadioButton()
    self.concatenate_rButton.setText("Concatenate only")
    self.concatenate_rButton.clicked.connect(self.set_rButton)

    self.first_row_label = QLabel()
    self.first_row_label.setText("First row")
    self.first_row_entry = QLineEdit()
    self.first_row_entry.editingFinished.connect(self.set_first_row)

    self.last_row_label = QLabel()
    self.last_row_label.setText("Last row")
    self.last_row_entry = QLineEdit()
    self.last_row_entry.editingFinished.connect(self.set_last_row)

    # --- Mode 2: horizontal stitching of half-acquisition data ---
    self.half_acquisition_rButton = QRadioButton()
    self.half_acquisition_rButton.setText("Horizontal stitching of half-acq. mode data (applies to tif images in the Input)")
    #self.half_acquisition_rButtonYfor a half-acqusition mode data (even number of tif files in the Input directory)")
    self.half_acquisition_rButton.clicked.connect(self.set_rButton)

    self.column_of_axis_label = QLabel()
    self.column_of_axis_label.setText("In which column the axis of rotation is")
    self.column_of_axis_entry = QLineEdit()
    self.column_of_axis_entry.editingFinished.connect(self.set_axis_column)

    # --- Action buttons ---
    self.help_button = QPushButton()
    self.help_button.setText("Help")
    self.help_button.clicked.connect(self.help_button_pressed)

    self.delete_button = QPushButton()
    self.delete_button.setText("Delete output dir")
    self.delete_button.clicked.connect(self.delete_button_pressed)

    self.stitch_button = QPushButton()
    self.stitch_button.setText("Stitch")
    self.stitch_button.clicked.connect(self.stitch_button_pressed)
    self.stitch_button.setStyleSheet("color:royalblue;font-weight:bold")

    self.import_parameters_button = QPushButton("Import Parameters from File")
    self.import_parameters_button.clicked.connect(self.import_parameters_button_pressed)

    self.save_parameters_button = QPushButton("Save Parameters to File")
    self.save_parameters_button.clicked.connect(self.save_parameters_button_pressed)

    self.set_layout()
def set_layout(self):
    """Stack the directory column and four option grids vertically in one outer grid.

    Each placement tuple is ``(widget, row, column)`` or
    ``(widget, row, column, row_span, column_span)``, forwarded to ``addWidget``.
    """
    directories = QVBoxLayout()
    for widget in (self.input_dir_button, self.input_dir_entry,
                   self.tmp_dir_button, self.tmp_dir_entry,
                   self.output_dir_button, self.output_dir_entry):
        directories.addWidget(widget)

    stitching_options = (
        (self.types_of_images_label, 0, 0),
        (self.types_of_images_entry, 0, 1),
        (self.orthogonal_checkbox, 1, 0, 1, 2),
        (self.start_stop_step_label, 2, 0),
        (self.start_stop_step_entry, 2, 1),
        (self.sample_moved_down_checkbox, 3, 0),
        (self.interpolate_regions_rButton, 4, 0, 1, 2),
        (self.num_overlaps_label, 5, 0),
        (self.num_overlaps_entry, 5, 1),
        (self.clip_histogram_checkbox, 6, 0),
        (self.min_value_label, 7, 0),
        (self.min_value_entry, 7, 1),
        (self.max_value_label, 8, 0),
        (self.max_value_entry, 8, 1),
    )
    concatenation_options = (
        (self.concatenate_rButton, 0, 0, 1, 2),
        (self.first_row_label, 1, 0),
        (self.first_row_entry, 1, 1),
        (self.last_row_label, 1, 2),
        (self.last_row_entry, 1, 3),
    )
    half_acquisition_options = (
        (self.half_acquisition_rButton, 0, 0, 1, 2),
        (self.column_of_axis_label, 1, 0),
        (self.column_of_axis_entry, 1, 1),
    )
    action_buttons = (
        (self.help_button, 0, 0),
        (self.delete_button, 0, 1),
        (self.stitch_button, 0, 2),
        (self.import_parameters_button, 1, 0, 1, 2),
        (self.save_parameters_button, 1, 2),
    )

    outer = QGridLayout()
    outer.addItem(directories, 0, 0)
    sections = (stitching_options, concatenation_options,
                half_acquisition_options, action_buttons)
    for outer_row, placements in enumerate(sections, start=1):
        section_grid = QGridLayout()
        for placement in placements:
            section_grid.addWidget(*placement)
        outer.addItem(section_grid, outer_row, 0)
    self.setLayout(outer)
def init_values(self):
    """Create ``self.parameters`` with default stitching settings and mirror them in the GUI.

    Each default is stored first and then displayed, in the same order as the
    widgets appear, so any signal fired by a widget update sees the value
    already in place.
    """
    home_dir = os.path.expanduser('~')
    self.parameters = {'parameters_type': 'ez_stitch'}

    def as_text(line_edit):
        return lambda value: line_edit.setText(str(value))

    def select_interpolation(_value):
        self.interpolate_regions_rButton.setChecked(True)
        self.concatenate_rButton.setChecked(False)
        self.half_acquisition_rButton.setChecked(False)

    defaults = (
        ('ezstitch_input_dir', home_dir, as_text(self.input_dir_entry)),
        ('ezstitch_temp_dir', os.path.join(home_dir, "tmp-ezstitch"), as_text(self.tmp_dir_entry)),
        ('ezstitch_output_dir', os.path.join(home_dir, "ezufo-stitched-images"), as_text(self.output_dir_entry)),
        ('ezstitch_type_image', "sli", as_text(self.types_of_images_entry)),
        ('ezstitch_stitch_orthogonal', True, self.orthogonal_checkbox.setChecked),
        ('ezstitch_start_stop_step', "200,2000,200", as_text(self.start_stop_step_entry)),
        ('ezstitch_sample_moved_down', False, self.sample_moved_down_checkbox.setChecked),
        ('ezstitch_stitch_type', 0, select_interpolation),
        ('ezstitch_num_overlap_rows', 60, as_text(self.num_overlaps_entry)),
        ('ezstitch_clip_histo', False, self.clip_histogram_checkbox.setChecked),
        ('ezstitch_histo_min', -0.0003, as_text(self.min_value_entry)),
        ('ezstitch_histo_max', 0.0002, as_text(self.max_value_entry)),
        ('ezstitch_first_row', 40, as_text(self.first_row_entry)),
        ('ezstitch_last_row', 440, as_text(self.last_row_entry)),
        ('ezstitch_axis_of_rotation', 245, as_text(self.column_of_axis_entry)),
    )
    for key, default, show in defaults:
        self.parameters[key] = default
        show(default)
def update_parameters(self, new_parameters):
    """Replace ``self.parameters`` with ``new_parameters`` and refresh every widget.

    Returns -1 (leaving everything unchanged) if ``new_parameters`` is not an
    'ez_stitch' parameter set, and 0 on success. Previously the success path
    returned None, so callers checking ``== 0`` never saw a success.
    """
    LOG.debug("Update parameters")
    if new_parameters['parameters_type'] != 'ez_stitch':
        print("Error: Invalid parameter file type: " + str(new_parameters['parameters_type']))
        return -1
    # Update parameters dictionary (which is passed to the stitching functions)
    self.parameters = new_parameters
    # Update displayed parameters for GUI; str()/bool() guard against YAML values
    # of an unexpected type, as the sibling tool groups do
    self.input_dir_entry.setText(str(self.parameters['ezstitch_input_dir']))
    self.tmp_dir_entry.setText(str(self.parameters['ezstitch_temp_dir']))
    self.output_dir_entry.setText(str(self.parameters['ezstitch_output_dir']))
    self.types_of_images_entry.setText(str(self.parameters['ezstitch_type_image']))
    self.orthogonal_checkbox.setChecked(bool(self.parameters['ezstitch_stitch_orthogonal']))
    self.start_stop_step_entry.setText(str(self.parameters['ezstitch_start_stop_step']))
    self.sample_moved_down_checkbox.setChecked(bool(self.parameters['ezstitch_sample_moved_down']))
    # 0 = interpolate, 1 = concatenate only, 2 = half-acquisition (see set_rButton)
    if self.parameters['ezstitch_stitch_type'] == 0:
        self.interpolate_regions_rButton.setChecked(True)
    elif self.parameters['ezstitch_stitch_type'] == 1:
        self.concatenate_rButton.setChecked(True)
    elif self.parameters['ezstitch_stitch_type'] == 2:
        self.half_acquisition_rButton.setChecked(True)
    self.num_overlaps_entry.setText(str(self.parameters['ezstitch_num_overlap_rows']))
    self.clip_histogram_checkbox.setChecked(bool(self.parameters['ezstitch_clip_histo']))
    self.min_value_entry.setText(str(self.parameters['ezstitch_histo_min']))
    self.max_value_entry.setText(str(self.parameters['ezstitch_histo_max']))
    self.first_row_entry.setText(str(self.parameters['ezstitch_first_row']))
    self.last_row_entry.setText(str(self.parameters['ezstitch_last_row']))
    self.column_of_axis_entry.setText(str(self.parameters['ezstitch_axis_of_rotation']))
    return 0
def set_rButton(self):
    """Record which stitching mode radio button is selected.

    The first checked button (in the order below) wins; if none is checked,
    the parameters are left untouched.
    """
    modes = (
        (self.interpolate_regions_rButton, "Interpolate regions", 0),
        (self.concatenate_rButton, "Concatenate only", 1),
        (self.half_acquisition_rButton, "Half-acquisition mode", 2),
    )
    for button, description, stitch_type in modes:
        if button.isChecked():
            LOG.debug(description)
            self.parameters['ezstitch_stitch_type'] = stitch_type
            break
def input_button_pressed(self):
    """
    Let the user pick the input directory and store it in the parameters.

    Cancelling the dialog keeps the previously selected directory.
    """
    LOG.debug("Input button pressed")
    directory = QFileDialog.getExistingDirectory(self)
    # getExistingDirectory returns an empty string when the dialog is cancelled
    if not directory:
        return
    self.parameters['ezstitch_input_dir'] = directory
    self.input_dir_entry.setText(directory)
def set_input_entry(self):
    """Store the input directory typed into the entry field."""
    typed_dir = str(self.input_dir_entry.text())
    LOG.debug("Input: " + typed_dir)
    self.parameters['ezstitch_input_dir'] = typed_dir
def temp_button_pressed(self):
    """
    Let the user pick the temporary directory and store it in the parameters.

    Cancelling the dialog keeps the previously selected directory.
    """
    LOG.debug("Temp button pressed")
    directory = QFileDialog.getExistingDirectory(self)
    # getExistingDirectory returns an empty string when the dialog is cancelled
    if not directory:
        return
    self.parameters['ezstitch_temp_dir'] = directory
    self.tmp_dir_entry.setText(directory)
def set_temp_entry(self):
    """Store the temporary directory typed into the entry field."""
    typed_dir = str(self.tmp_dir_entry.text())
    LOG.debug("Temp: " + typed_dir)
    self.parameters['ezstitch_temp_dir'] = typed_dir
def output_button_pressed(self):
    """
    Let the user pick the output directory and store it in the parameters.

    Cancelling the dialog keeps the previously selected directory.
    """
    LOG.debug("Output button pressed")
    directory = QFileDialog.getExistingDirectory(self)
    # getExistingDirectory returns an empty string when the dialog is cancelled
    if not directory:
        return
    self.parameters['ezstitch_output_dir'] = directory
    self.output_dir_entry.setText(directory)
def set_output_entry(self):
    """Store the output directory typed into the entry field."""
    typed_dir = str(self.output_dir_entry.text())
    LOG.debug("Output: " + typed_dir)
    self.parameters['ezstitch_output_dir'] = typed_dir
def set_type_images(self):
    """Store the image-type subdirectory name typed into the entry field."""
    image_type = str(self.types_of_images_entry.text())
    LOG.debug("Type of images: " + image_type)
    self.parameters['ezstitch_type_image'] = image_type
def set_stitch_checkbox(self):
    """Mirror the "stitch orthogonal" checkbox state into the parameters."""
    checked = bool(self.orthogonal_checkbox.isChecked())
    LOG.debug("Stitch orthogonal: " + str(checked))
    self.parameters['ezstitch_stitch_orthogonal'] = checked
def set_start_stop_step(self):
    """Store the start/stop/step image range text exactly as typed."""
    image_range = str(self.start_stop_step_entry.text())
    LOG.debug("Images to be stitched: " + image_range)
    self.parameters['ezstitch_start_stop_step'] = image_range
def set_sample_moved_down(self):
    """Mirror the "sample moved down" checkbox state into the parameters."""
    moved_down = bool(self.sample_moved_down_checkbox.isChecked())
    LOG.debug("Sample moved down: " + str(moved_down))
    self.parameters['ezstitch_sample_moved_down'] = moved_down
def set_overlap(self):
    """
    Store the number of overlapping rows typed into the entry field.

    Text that is not an integer (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.num_overlaps_entry.text())
    LOG.debug("Num overlapping rows: " + text)
    try:
        self.parameters['ezstitch_num_overlap_rows'] = int(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-integer value for overlapping rows: " + text)
def set_histogram_checkbox(self):
    """Mirror the "clip histogram" checkbox state into the parameters."""
    clip = bool(self.clip_histogram_checkbox.isChecked())
    LOG.debug("Clip histogram: " + str(clip))
    self.parameters['ezstitch_clip_histo'] = clip
def set_min_value(self):
    """
    Store the histogram minimum typed into the entry field.

    Text that is not a number (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.min_value_entry.text())
    LOG.debug("Min value: " + text)
    try:
        self.parameters['ezstitch_histo_min'] = float(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-numeric histogram min: " + text)
def set_max_value(self):
    """
    Store the histogram maximum typed into the entry field.

    Text that is not a number (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.max_value_entry.text())
    LOG.debug("Max value: " + text)
    try:
        self.parameters['ezstitch_histo_max'] = float(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-numeric histogram max: " + text)
def set_first_row(self):
    """
    Store the first row typed into the entry field.

    Text that is not an integer (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.first_row_entry.text())
    LOG.debug("First row: " + text)
    try:
        self.parameters['ezstitch_first_row'] = int(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-integer first row: " + text)
def set_last_row(self):
    """
    Store the last row typed into the entry field.

    Text that is not an integer (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.last_row_entry.text())
    LOG.debug("Last row: " + text)
    try:
        self.parameters['ezstitch_last_row'] = int(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-integer last row: " + text)
def set_axis_column(self):
    """
    Store the column of the rotation axis typed into the entry field.

    Text that is not an integer (e.g. an empty field while editing) is ignored
    and the previous value is kept.
    """
    text = str(self.column_of_axis_entry.text())
    LOG.debug("Column of axis: " + text)
    try:
        self.parameters['ezstitch_axis_of_rotation'] = int(text)
    except ValueError:
        # When run as a Qt slot, an escaping exception terminates PyQt5 (>= 5.5) apps.
        LOG.debug("Ignoring non-integer axis column: " + text)
def help_button_pressed(self):
    """Show a message box describing the input layout expected by ezstitch."""
    LOG.debug("Help button pressed")
    # One entry per displayed line; joined with newlines below.
    help_lines = [
        "Stitches images vertically",
        "Directory structure is, f.i., Input/000, Input/001,...Input/00N",
        "Each 000, 001, ... 00N directory must have identical subdirectory \"Type\"",
        "Selected range of images from \"Type\" directory will be stitched vertically",
        "across all subdirectories in the Input directory",
        "to be added as options:",
        "(1) orthogonal reslicing, (2) interpolation, (3) horizontal stitching",
    ]
    QMessageBox.information(self, "Help", "\n".join(help_lines))
def delete_button_pressed(self):
    """
    After confirmation, recursively delete the configured output directory.

    Does nothing if the directory does not exist or the user declines.
    A deletion failure is reported with a warning dialog.
    """
    LOG.debug("Delete button pressed")
    output_dir = self.parameters['ezstitch_output_dir']
    if not os.path.exists(output_dir):
        return
    qm = QMessageBox()
    rep = qm.question(self, '', f"{output_dir} \n"
                                "will be removed. Continue?", qm.Yes | qm.No)
    if rep != qm.Yes:
        return
    try:
        rmtree(output_dir)
    except OSError as err:
        # rmtree reports failures as OSError; anything else should propagate
        LOG.error("Could not delete {}: {}".format(output_dir, err))
        warning_message('Error while deleting directory')
def stitch_button_pressed(self):
    """
    Run the stitching mode selected by ``ezstitch_stitch_type``.

    If the temporary directory exists, it is deleted first, but only after the
    user confirms. The run is refused when the output directory already exists.
    Stitch types: 0 = interpolate overlapping regions and equalize intensity,
    1 = concatenate only, 2 = half acquisition mode. After a run that created
    the output directory, the parameters are saved there as YAML.
    """
    LOG.debug("Stitch button pressed")
    temp_dir = self.parameters['ezstitch_temp_dir']
    if os.path.exists(temp_dir):
        qm = QMessageBox()
        rep = qm.question(self, '', "Temporary dir is not empty. Is it safe to delete it?", qm.Yes | qm.No)
        if rep != qm.Yes:
            return
        try:
            rmtree(temp_dir)
        except OSError as err:
            # rmtree reports failures as OSError; anything else should propagate
            LOG.error("Could not delete {}: {}".format(temp_dir, err))
            warning_message('Error while deleting directory')
            return
    if os.path.exists(self.parameters['ezstitch_output_dir']):
        warning_message('Output directory exists. Delete it or select another one.')
        return
    print("======= Begin Stitching =======")
    stitch_type = self.parameters['ezstitch_stitch_type']
    if stitch_type == 0:
        # Interpolate overlapping regions and equalize intensity
        main_sti_mp(self.parameters)
    elif stitch_type == 1:
        # Concatenate only
        main_conc_mp(self.parameters)
    elif stitch_type == 2:
        # Half acquisition mode
        main_360_mp_depth1(self.parameters['ezstitch_input_dir'],
                           self.parameters['ezstitch_output_dir'],
                           self.parameters['ezstitch_axis_of_rotation'], 0)
    if os.path.isdir(self.parameters['ezstitch_output_dir']):
        # NOTE(review): file name says "ezmview" although this is ezstitch; kept
        # unchanged in case other tools look for it - confirm.
        params_file_path = os.path.join(self.parameters['ezstitch_output_dir'], 'ezmview_params.yaml')
        params.save_parameters(self.parameters, params_file_path)
    print("==== Waiting for Next Task ====")
def import_parameters_button_pressed(self):
    """
    Load ez_stitch parameters from a user-selected YAML file and apply them.

    Cancelling the dialog does nothing. A missing file or invalid YAML is
    reported on the console instead of raising.
    """
    LOG.debug("Import params button clicked")
    file_path, _ = QFileDialog.getOpenFileName(self, filter="*.yaml")
    if not file_path:
        # Dialog cancelled
        return
    try:
        with open(file_path, 'r') as file_in:
            # NOTE(review): yaml.FullLoader can construct some Python objects; for a
            # plain parameter dict yaml.safe_load would be the safer choice.
            new_parameters = yaml.load(file_in, Loader=yaml.FullLoader)
    except FileNotFoundError:
        print("You need to select a valid input file")
        return
    except yaml.YAMLError as err:
        print("Could not parse parameters file: " + str(err))
        return
    # update_parameters signals rejection with -1
    if self.update_parameters(new_parameters) != -1:
        print("Parameters file loaded from: " + str(file_path))
def save_parameters_button_pressed(self):
    """
    Save the current ez_stitch parameters to a user-selected YAML file.

    A ``.yaml`` extension is appended if the chosen name has no extension.
    Cancelling the dialog writes nothing.
    """
    LOG.debug("Save params button clicked")
    selected_path, _ = QFileDialog.getSaveFileName(self, filter="*.yaml")
    if not selected_path:
        # Dialog cancelled; without this check a file named ".yaml" would be
        # created in the current working directory.
        return
    # If the user doesn't enter the .yaml extension then append it to filepath
    if os.path.splitext(selected_path)[1] == "":
        file_path = selected_path + ".yaml"
    else:
        file_path = selected_path
    try:
        # The with-block closes (and flushes) the file before success is reported
        with open(file_path, 'w') as file_out:
            yaml.dump(self.parameters, file_out)
    except FileNotFoundError:
        print("You need to select a directory and use a valid file name")
        return
    except OSError as err:
        print("Could not save parameters file: " + str(err))
        return
    print("Parameters file saved at: " + str(file_path))
tofu-0.12.0/tofu/ez/GUI/__init__.py 0000664 0000000 0000000 00000000001 14237137211 0016667 0 ustar 00root root 0000000 0000000
tofu-0.12.0/tofu/ez/GUI/default_settings.yaml 0000664 0000000 0000000 00000006034 14237137211 0021022 0 ustar 00root root 0000000 0000000 # Default configuration file for ez_ufo_qt
# Modify at your own peril
---
main_config_input_dir: ""
main_config_temp_dir: ""
main_config_output_dir: ""
main_config_darks_dir_name: "darks"
main_config_flats_dir_name: "flats"
main_config_tomo_dir_name: "tomo"
main_config_flats2_dir_name: "flats2"
main_config_save_multipage_tiff: false
main_cor_axis_search_method: 1
main_cor_axis_search_interval: "1010,1030,0.5"
main_cor_search_row_start: 100
main_cor_recon_patch_size: 256
main_cor_axis_column: 0.0
main_cor_axis_increment_step: 0.0
main_filters_remove_spots: false
main_filters_remove_spots_threshold: 1000
main_filters_remove_spots_blur_sigma: 2
main_filters_ring_removal: false
main_filters_ring_removal_ufo_lpf: true
main_filters_ring_removal_ufo_lpf_1d_or_2d: true
main_filters_ring_removal_ufo_lpf_sigma_horizontal: 3
main_filters_ring_removal_ufo_lpf_sigma_vertical: 1
main_filters_ring_removal_sarepy_window_size: 21
main_filters_ring_removal_sarepy_wide: false
main_filters_ring_removal_sarepy_window: 91
main_filters_ring_removal_sarepy_SNR: 3
main_pr_phase_retrieval: false
main_pr_photon_energy: 20
main_pr_pixel_size: 3.6
main_pr_detector_distance: 0.1
main_pr_delta_beta_ratio: 200
main_region_select_rows: false
main_region_first_row: 100
main_region_number_rows: 200
main_region_nth_row: 20
main_region_clip_histogram: false
main_region_bit_depth: 8
main_region_histogram_min: 0.0
main_region_histogram_max: 0.0
main_config_preprocess: false
main_config_preprocess_command: "remove-outliers size=3 threshold=500 sign=1"
main_region_rotate_volume_clock: 0.0
main_region_crop_slices: false
main_region_crop_x: 0
main_region_crop_width: 0
main_region_crop_y: 0
main_region_crop_height: 0
main_config_dry_run: false
main_config_save_params: true
main_config_keep_temp: false
advanced_ffc_sinFFC: false
advanced_ffc_method: 1
advanced_ffc_eigen_pco_reps: 4
advanced_ffc_eigen_pco_downsample: 2
advanced_ffc_downsample: 4
main_config_open_viewer: True
main_config_common_flats_darks: False
main_config_darks_path: "Absolute path to darks"
main_config_flats_path: "Absolute path to flats"
main_config_flats2_checkbox: False
main_config_flats2_path: "Absolute path to flats2"
#NLMDN Settings
advanced_nlmdn_apply_after_reco: False
advanced_nlmdn_input_dir: ""
advanced_nlmdn_input_is_file: False
advanced_nlmdn_output_dir: ""
advanced_nlmdn_save_bigtiff: False
advanced_nlmdn_sim_search_radius: 10
advanced_nlmdn_patch_radius: 3
advanced_nlmdn_smoothing_control: 0.0
advanced_nlmdn_noise_std: 0.0
advanced_nlmdn_window: 0.0
advanced_nlmdn_fast: True
advanced_nlmdn_estimate_sigma: False
advanced_nlmdn_dry_run: False
#ADVANCED TOFU Settings
advanced_advtofu_lamino_angle: 30
advanced_adv_tofu_z_axis_rotation: 360
advanced_advtofu_center_position_z: ""
advanced_advtofu_y_axis_rotation: ""
advanced_advtofu_aux_ffc_dark_scale: ""
advanced_advtofu_aux_ffc_flat_scale: ""
#Optimization Settings
advanced_advtofu_extended_settings: False
advanced_optimize_verbose_console: False
advanced_optimize_slice_mem_coeff: 0.5
advanced_optimize_num_gpus: ""
advanced_optimize_slices_per_device: ""
... tofu-0.12.0/tofu/ez/GUI/ezufo_launcher.py 0000664 0000000 0000000 00000025355 14237137211 0020164 0 ustar 00root root 0000000 0000000 import logging
import os
import sys
from PyQt5 import QtWidgets as qtw
from tofu.ez.GUI.Main.centre_of_rotation import CentreOfRotationGroup
from tofu.ez.GUI.Main.filters import FiltersGroup
from tofu.ez.GUI.Advanced.ffc import FFCGroup
from tofu.ez.GUI.Main.phase_retrieval import PhaseRetrievalGroup
from tofu.ez.GUI.Main.region_and_histogram import ROIandHistGroup
from tofu.ez.GUI.Main.config import ConfigGroup
from tofu.ez.main import clean_tmp_dirs
from tofu.ez.yaml_in_out import Yaml_IO
from tofu.ez.GUI.image_viewer import ImageViewerGroup
import tofu.ez.params as parameters
from tofu.ez.GUI.Advanced.advanced import AdvancedGroup
from tofu.ez.GUI.Advanced.optimization import OptimizationGroup
from tofu.ez.GUI.Advanced.nlmdn import NLMDNGroup
from tofu.ez.GUI.Stitch_tools_tab.ez_360_multi_stitch_qt import MultiStitch360Group
from tofu.ez.GUI.Stitch_tools_tab.ezstitch_qt import EZStitchGroup
from tofu.ez.GUI.Stitch_tools_tab.ezmview_qt import EZMViewGroup
from tofu.ez.GUI.Stitch_tools_tab.ez_360_overlap_qt import Overlap360Group
from tofu.ez.GUI.login_dialog import Login
LOG = logging.getLogger(__name__)
class GUI(qtw.QWidget):
    """
    Creates main GUI

    Top-level EZ-UFO window with four tabs: "Main", "Advanced",
    "Stitching tools" and "Image Viewer". On construction it reads
    ``default_settings.yaml`` (next to this module) into
    ``tofu.ez.params.params`` and connects signals between the groups.
    """

    def __init__(self, *args, **kwargs):
        super(GUI, self).__init__(*args, **kwargs)
        self.setWindowTitle("EZ-UFO")
        self.setStyleSheet("font: 10pt; font-family: Arial")

        # NOTE: the login dialog is disabled. To enable it, set
        # self.login_parameters = {} here and schedule it with
        # QTimer.singleShot(0, self.login).

        # Read in default parameter settings from yaml file. The path is built
        # outside the try block so the error message can always refer to it.
        settings_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), "default_settings.yaml"
        )
        try:
            self.yaml_io = Yaml_IO()
            self.yaml_data = self.yaml_io.read_yaml(settings_path)
            parameters.params = dict(self.yaml_data)
            parameters.params["parameters_type"] = "ez_ufo_reco"
        except FileNotFoundError:
            print("Could not load default settings from: " + str(settings_path))

        # Initialize tab screen (tab contents are assigned in set_layout)
        self.tabs = qtw.QTabWidget()
        self.tab1 = qtw.QWidget()  # "Main"
        self.tab2 = qtw.QWidget()  # "Advanced"
        self.tab3 = qtw.QWidget()  # "Stitching tools"
        self.tab4 = qtw.QWidget()  # "Image Viewer"

        # Create and setup classes for each section of GUI
        # Main Tab
        self.centre_of_rotation_group = CentreOfRotationGroup()
        self.centre_of_rotation_group.init_values()
        self.filters_group = FiltersGroup()
        self.filters_group.init_values()
        self.ffc_group = FFCGroup()
        self.ffc_group.init_values()
        self.phase_retrieval_group = PhaseRetrievalGroup()
        self.phase_retrieval_group.init_values()
        self.binning_group = ROIandHistGroup()
        self.binning_group.init_values()
        self.config_group = ConfigGroup()
        self.config_group.init_values()

        # Image Viewer
        self.image_group = ImageViewerGroup()

        # Advanced Tab
        self.advanced_group = AdvancedGroup()
        self.advanced_group.init_values()
        self.optimization_group = OptimizationGroup()
        self.optimization_group.init_values()
        self.nlmdn_group = NLMDNGroup()
        self.nlmdn_group.init_values()

        # Stitch_tools_tab Tab
        self.multi_stitch_group = MultiStitch360Group()
        self.multi_stitch_group.init_values()
        self.ezmview_group = EZMViewGroup()
        self.ezmview_group.init_values()
        self.ezstitch_group = EZStitchGroup()
        self.ezstitch_group.init_values()
        self.overlap_group = Overlap360Group()
        self.overlap_group.init_values()

        self.set_layout()
        self.resize(0, 0)  # window to minimum size

        # When new settings are imported signal is sent and this catches it to update params for each GUI object
        self.config_group.signal_update_vals_from_params.connect(self.update_values_from_params)
        # When RECO is done send signal from config
        self.config_group.signal_reco_done.connect(self.switch_to_image_tab)
        # To pass directory names from config tab to stitch tab when button pressed
        self.multi_stitch_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names)
        self.overlap_group.get_fdt_names_on_stitch_pressed.connect(self.config_group.set_fdt_names)

        # QAction.triggered emits a bool, which closeEvent cannot handle;
        # close() goes through closeEvent with a real QCloseEvent instead.
        finish = qtw.QAction("Quit", self)
        finish.triggered.connect(self.close)
        self.show()

    def set_layout(self):
        """
        Set the layout of groups/tabs for the overall application layout
        """
        # Passing self makes this the top-level layout of the window
        layout = qtw.QVBoxLayout(self)

        main_layout = qtw.QGridLayout()
        main_layout.addWidget(self.centre_of_rotation_group, 0, 0)
        main_layout.addWidget(self.filters_group, 0, 1)
        main_layout.addWidget(self.phase_retrieval_group, 1, 0)
        main_layout.addWidget(self.binning_group, 1, 1)
        main_layout.addWidget(self.config_group, 2, 0, 2, 0)

        image_layout = qtw.QGridLayout()
        image_layout.addWidget(self.image_group, 0, 0)

        advanced_layout = qtw.QGridLayout()
        advanced_layout.addWidget(self.ffc_group, 0, 0)
        advanced_layout.addWidget(self.advanced_group, 1, 0)
        advanced_layout.addWidget(self.optimization_group, 1, 1)
        advanced_layout.addWidget(self.nlmdn_group, 2, 0)

        helpers_layout = qtw.QGridLayout()
        helpers_layout.addWidget(self.ezmview_group, 0, 0)
        helpers_layout.addWidget(self.overlap_group, 0, 1)
        helpers_layout.addWidget(self.multi_stitch_group, 1, 0)
        helpers_layout.addWidget(self.ezstitch_group, 1, 1)

        # Add tabs
        self.tabs.addTab(self.tab1, "Main")
        self.tabs.addTab(self.tab2, "Advanced")
        self.tabs.addTab(self.tab3, "Stitching tools")
        self.tabs.addTab(self.tab4, "Image Viewer")

        # Install each tab's layout with setLayout() only: assigning to
        # ``tabN.layout`` would shadow QWidget.layout() on those widgets.
        self.tab1.setLayout(main_layout)
        self.tab2.setLayout(advanced_layout)
        self.tab3.setLayout(helpers_layout)
        self.tab4.setLayout(image_layout)

        # Add tabs to widget
        layout.addWidget(self.tabs)

    def update_values_from_params(self):
        """
        Updates displayed values when loaded in from external .yaml file of parameters
        """
        LOG.debug("Update Values from Params")
        LOG.debug(parameters.params)
        self.centre_of_rotation_group.set_values_from_params()
        self.filters_group.set_values_from_params()
        self.ffc_group.set_values_from_params()
        self.phase_retrieval_group.set_values_from_params()
        self.binning_group.set_values_from_params()
        self.config_group.set_values_from_params()
        self.nlmdn_group.set_values_from_params()
        self.advanced_group.set_values_from_params()
        self.optimization_group.set_values_from_params()

    def switch_to_image_tab(self):
        """
        Function is called after reconstruction
        when checkbox "Load images and open viewer after reconstruction" is enabled
        Automatically loads images from the output reconstruction directory for viewing

        A single file in ``<output>/sli`` is opened as one image, several files as a stack.
        """
        if parameters.params["main_config_open_viewer"] is not True:
            return
        LOG.debug("Switch to Image Tab")
        # tab4 holds the image viewer (see set_layout); tab2 is "Advanced"
        self.tabs.setCurrentWidget(self.tab4)
        slice_dir = os.path.join(str(parameters.params['main_config_output_dir']), 'sli')
        if not os.path.isdir(slice_dir):
            print("No output directory found")
            return
        files = os.listdir(slice_dir)
        # TODO: load the images in a separate thread
        if len(files) == 1:
            print("Only one file in {}: Opening single image {}".format(slice_dir, files[0]))
            self.image_group.open_image_from_filepath(os.path.join(slice_dir, str(files[0])))
        else:
            print("Multiple files in {}: Opening stack of images".format(slice_dir))
            self.image_group.open_stack_from_path(slice_dir)

    def closeEvent(self, event):
        """
        Creates verification message box
        Cleans up temporary directories when user quits application

        :param event: close event; accepted to close the window, ignored to keep it open
        """
        LOG.debug("QUIT")
        reply = qtw.QMessageBox.question(self, 'Quit', 'Are you sure you want to quit?',
                                         qtw.QMessageBox.Yes | qtw.QMessageBox.No, qtw.QMessageBox.No)
        if reply != qtw.QMessageBox.Yes:
            event.ignore()
            return
        temp_dir = parameters.params['main_config_temp_dir']
        # remove all directories with projections
        clean_tmp_dirs(temp_dir, self.config_group.get_fdt_names())
        # remove axis-search dir too; skipped when no temp dir is configured so a
        # relative "axis-search" in the current working directory is never touched
        if temp_dir:
            import shutil
            axis_search_dir = os.path.join(temp_dir, 'axis-search')
            if os.path.isdir(axis_search_dir):
                shutil.rmtree(axis_search_dir, ignore_errors=True)
        event.accept()

    def login(self):
        """
        Show the login dialog. If it is accepted, fill the input ("<expdir>/raw")
        and output ("<expdir>/rec") directories. If it is rejected, close the window.

        NOTE(review): reads ``self.login_parameters``, which only exists when the
        disabled login code in ``__init__`` is re-enabled; ``'expdir'`` is
        presumably filled in by the Login dialog - confirm.
        """
        login_dialog = Login(self.login_parameters)
        if login_dialog.exec_() != qtw.QDialog.Accepted:
            self.exit()
        else:
            self.config_group.input_dir_entry.setText(self.login_parameters['expdir'] + "/raw")
            self.config_group.set_input_dir()
            self.config_group.output_dir_entry.setText(self.login_parameters['expdir'] + "/rec")
            self.config_group.set_output_dir()
            # TODO: create a per-experiment log file in the experiment directory

    def exit(self):
        """Close the main window."""
        self.close()
def main_qt(args=None):
    """
    Start the EZ-UFO Qt application and exit with the event loop's return code.

    :param args: Unused; Qt receives ``sys.argv``.
    """
    qt_app = qtw.QApplication(sys.argv)
    # Keep a reference so the main window stays alive while the event loop runs
    main_window = GUI()  # noqa: F841
    exit_code = qt_app.exec_()
    sys.exit(exit_code)


# Allow launching the GUI by running this module directly
if __name__ == "__main__":
    main_qt()
tofu-0.12.0/tofu/ez/GUI/image_viewer.py 0000664 0000000 0000000 00000035752 14237137211 0017620 0 ustar 00root root 0000000 0000000 import os
import logging
import pyqtgraph as pg
import numpy as np
import tifffile
from PyQt5.QtWidgets import (
QPushButton,
QGroupBox,
QLabel,
QDoubleSpinBox,
QRadioButton,
QScrollBar,
QVBoxLayout,
QGridLayout,
QFileDialog,
QMessageBox,
)
from PyQt5.QtCore import Qt
import tofu.ez.image_read_write as image_read_write
# TODO: Integrate the axis search tab of the tofu GUI into this interface
LOG = logging.getLogger(__name__)
class ImageViewerGroup(QGroupBox):
def __init__(self):
super().__init__()
#TODO: initialize on every opening with explicit data type
#mmatching the data format being opened.
#must check that there is enough RAM before loading!!
self.tiff_arr = np.empty([0, 0, 0]) # float32
self.img_arr = np.empty([0, 0])
self.bit_depth = 32
self.open_file_button = QPushButton("Open Image File")
self.open_file_button.clicked.connect(self.open_image_from_file)
self.open_file_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.open_stack_button = QPushButton("Open Image Stack")
self.open_stack_button.clicked.connect(self.open_stack_from_directory)
self.open_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.save_file_button = QPushButton("Save Image File")
self.save_file_button.clicked.connect(self.save_image_to_file)
self.save_file_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.save_stack_button = QPushButton("Save Image Stack")
self.save_stack_button.clicked.connect(self.save_stack_to_directory)
self.save_stack_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.open_big_tiff_button = QPushButton("Open BigTiff")
self.open_big_tiff_button.clicked.connect(self.open_big_tiff)
self.open_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.save_big_tiff_button = QPushButton("Save BigTiff")
self.save_big_tiff_button.clicked.connect(self.save_stack_to_big_tiff)
self.save_big_tiff_button.setStyleSheet("background-color: lightgrey; font: 11pt")
self.save_8bit_rButton = QRadioButton()
self.save_8bit_rButton.setText("Save as 8-bit")
self.save_8bit_rButton.clicked.connect(self.set_8bit)
self.save_8bit_rButton.setChecked(False)
self.save_16bit_rButton = QRadioButton()
self.save_16bit_rButton.setText("Save as 16-bit")
self.save_16bit_rButton.clicked.connect(self.set_16bit)
self.save_16bit_rButton.setChecked(False)
self.save_32bit_rButton = QRadioButton()
self.save_32bit_rButton.setText("Save as 32-bit")
self.save_32bit_rButton.clicked.connect(self.set_32bit)
self.save_32bit_rButton.setChecked(True)
self.hist_min_label = QLabel("Histogram Min:")
self.hist_min_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
self.hist_max_label = QLabel("Histogram Max:")
self.hist_max_label.setAlignment(Qt.AlignRight | Qt.AlignVCenter)
self.hist_min_input = QDoubleSpinBox()
self.hist_min_input.setDecimals(12)
self.hist_min_input.setRange(-10, 10)
self.hist_min_input.valueChanged.connect(self.min_spin_changed)
self.hist_max_input = QDoubleSpinBox()
self.hist_max_input.setDecimals(12)
self.hist_max_input.setRange(-10, 10)
self.hist_max_input.valueChanged.connect(self.max_spin_changed)
self.apply_histogram_button = QPushButton("Apply Histogram to Image Stack")
self.apply_histogram_button.clicked.connect(self.apply_histogram_button_clicked)
self.image_window = pg.ImageView()
self.image_window.ui.histogram.gradient.hide()
self.histo = self.image_window.getHistogramWidget()
self.scroller = QScrollBar(Qt.Horizontal)
self.scroller.orientation()
self.scroller.setEnabled(False)
self.scroller.valueChanged.connect(self.scroll_changed)
self.set_layout()
def set_layout(self):
vbox = QVBoxLayout()
vbox.addWidget(self.save_8bit_rButton)
vbox.addWidget(self.save_16bit_rButton)
vbox.addWidget(self.save_32bit_rButton)
gridbox = QGridLayout()
gridbox.addWidget(self.hist_max_label, 0, 0)
gridbox.addWidget(self.hist_max_input, 0, 1)
gridbox.addWidget(self.hist_min_label, 1, 0)
gridbox.addWidget(self.hist_min_input, 1, 1)
layout = QGridLayout()
layout.addWidget(self.open_file_button, 0, 0)
layout.addWidget(self.save_file_button, 1, 0)
layout.addWidget(self.open_stack_button, 0, 1)
layout.addWidget(self.save_stack_button, 1, 1)
layout.addWidget(self.open_big_tiff_button, 0, 2)
layout.addWidget(self.save_big_tiff_button, 1, 2)
layout.addItem(vbox, 0, 3, 2, 1)
layout.addItem(gridbox, 0, 4, 2, 1)
layout.addWidget(self.apply_histogram_button, 0, 5)
layout.addWidget(self.image_window, 2, 0, 1, 6)
layout.addWidget(self.scroller, 4, 0, 1, 5)
self.setLayout(layout)
self.resize(640, 480)
self.show()
def scroll_changed(self):
"""
Updated the currently displayed image based on position of scroll bar
:return: None
"""
self.image_window.setImage(self.tiff_arr[self.scroller.value()].T)
def open_image_from_file(self):
"""
Opens and displays a single image (.tif) specified by the user in the file dialog
:return: None
"""
LOG.debug("Open image button pressed")
options = QFileDialog.Options()
filePath, _ = QFileDialog.getOpenFileName(
self, "Open .tif Image File", "", "Tiff Files (*.tif *.tiff)", options=options
)
if filePath:
LOG.debug("Import image path: " + filePath)
self.img_arr = image_read_write.read_image(filePath)
self.image_window.setImage(self.img_arr.T)
self.scroller.setEnabled(False)
def open_image_from_filepath(self, filePath):
"""
Opens and displays a single image (.tif) contained in a directory - (used when one slice is reconstructed)
:param filePath: Full path and filename
:return: None
"""
LOG.debug("Open image from filepath: " + str(filePath))
if filePath:
LOG.debug("Import image path: " + filePath)
self.img_arr = image_read_write.read_image(filePath)
self.image_window.setImage(self.img_arr.T)
self.scroller.setEnabled(False)
def save_image_to_file(self):
"""
Saves the currently displayed image to a file (.tif) specified by the user in the file dialog
:return: None
"""
LOG.debug("Save image to file")
options = QFileDialog.Options()
filepath, _ = QFileDialog.getSaveFileName(
self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options
)
if filepath:
LOG.debug(filepath)
bit_depth_string = self.check_bit_depth(self.bit_depth)
img = self.image_window.imageItem.qimage
# https://www.programmersought.com/article/73475006380/
size = img.size()
s = img.bits().asstring(
size.width() * size.height() * img.depth() // 8
) # format 0xffRRGGBB
arr = np.fromstring(s, dtype=np.uint8).reshape(
(size.height(), size.width(), img.depth() // 8)
)
image_read_write.write_image(
arr.T[0].T, os.path.dirname(filepath), os.path.basename(filepath), bit_depth_string
)
def open_stack_from_directory(self):
"""
Opens all images (.tif) in a directory and displays them. Allows for scrolling through images with slider
:return: None
"""
LOG.debug("Open image stack button pressed")
dir_explore = QFileDialog()
directory = dir_explore.getExistingDirectory()
if directory:
try:
tiff_list = (".tif", ".tiff")
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Loading Images...")
msg.setText("Loading Images from Directory")
msg.show()
self.tiff_arr = image_read_write.read_all_images(directory, tiff_list)
self.scroller.setRange(0, self.tiff_arr.shape[0] - 1)
self.scroller.setEnabled(True)
self.image_window.setImage(self.tiff_arr[0].T)
msg.close()
mid_index = self.tiff_arr.shape[0] // 2
self.scroller.setValue(mid_index)
except image_read_write.InvalidDataSetError:
print("Invalid Data Set")
def open_stack_from_path(self, dir_path: str):
"""
Read images (.tif) from directory path into RAM as 3D numpy array
:param dir_path: Path to directory containing multiple .tiff image files
"""
LOG.debug("Open stack from path")
try:
tiff_list = (".tif", ".tiff")
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Loading Images...")
msg.setText("Loading Images from Directory")
msg.show()
self.tiff_arr = image_read_write.read_all_images(dir_path, tiff_list)
self.scroller.setRange(0, self.tiff_arr.shape[0] - 1)
self.scroller.setEnabled(True)
self.image_window.setImage(self.tiff_arr[0].T)
msg.close()
mid_index = self.tiff_arr.shape[0] // 2
self.scroller.setValue(mid_index)
except image_read_write.InvalidDataSetError:
print("Invalid Data Set")
def save_stack_to_directory(self):
"""
Saves images stored in numpy array to individual files (.tif) in directory specified by user dialog
Saves these images as BigTiff if checkbox is set to True
"""
LOG.debug("Save stack to directory button pressed")
LOG.debug("Saving with bitdepth: " + str(self.bit_depth))
dir_explore = QFileDialog()
directory = dir_explore.getExistingDirectory()
LOG.debug("Writing to directory: " + directory)
if directory:
bit_depth_string = self.check_bit_depth(self.bit_depth)
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Saving Images...")
msg.setText("Saving Images to Directory")
msg.show()
self.apply_histogram_to_images()
image_read_write.write_all_images(self.tiff_arr, directory, bit_depth_string)
msg.close()
def open_big_tiff(self):
"""
Opens images stored in a big tiff file (.tif) and displays them. Allows user to view them using scrollbar.
:return: None
"""
LOG.debug("Open big tiff button pressed")
options = QFileDialog.Options()
filePath, _ = QFileDialog.getOpenFileName(
self, "QFileDialog.getOpenFileName()", "", "All Files (*)", options=options
)
if filePath:
LOG.debug("Import image path: " + filePath)
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Loading Images...")
msg.setText("Loading Images from BigTiff")
msg.show()
self.tiff_arr = tifffile.imread(filePath).astype(dtype=np.float32)
self.scroller.setRange(0, self.tiff_arr.shape[0] - 1)
self.scroller.setEnabled(True)
self.image_window.setImage(self.tiff_arr[0].T)
msg.close()
mid_index = self.tiff_arr.shape[0] // 2
self.scroller.setValue(mid_index)
def save_stack_to_big_tiff(self):
"""
Saves the stack of images currently loaded into RAM to a single bigtif file
:return: None
"""
LOG.debug("Save stack to bigtiff button pressed")
LOG.debug("Saving with bitdepth: " + str(self.bit_depth))
dir_explore = QFileDialog()
options = QFileDialog.Options()
filepath, _ = QFileDialog.getSaveFileName(
self, "QFileDialog.getSaveFileName()", "", "Tiff Files (*.tif *.tiff)", options=options
)
if filepath:
msg = QMessageBox()
msg.setIcon(QMessageBox.Information)
msg.setWindowTitle("Saving Images...")
msg.setText("Saving Images to BigTiff")
msg.show()
# self.apply_histogram_to_images()
bit_depth_string = self.check_bit_depth(self.bit_depth)
tifffile.imwrite(filepath, self.tiff_arr, bigtiff=True, dtype=bit_depth_string)
msg.close()
def min_spin_changed(self):
"""
Changes the levels of the histogram widget if the min spinbox has been changed
:return: None
"""
histo = self.image_window.getHistogramWidget()
levels = histo.getLevels()
min_level = self.hist_min_input.value()
self.image_window.setLevels(min_level, levels[1])
def max_spin_changed(self):
"""
Changes the levels of the histogram widget if the max spinbox has been changed
:return: None
"""
histo = self.image_window.getHistogramWidget()
levels = histo.getLevels()
max_level = self.hist_max_input.value()
self.image_window.setLevels(levels[0], max_level)
def apply_histogram_button_clicked(self):
LOG.debug("Apply Histogram Button Clicked")
print("Applying histogram to images. This may take a moment.")
self.apply_histogram_to_images()
def apply_histogram_to_images(self):
"""
Gets the histogram levels of the currently displayed image and applies them to all images in RAM
:return: None
"""
levels = self.histo.getLevels()
self.tiff_arr = np.clip(self.tiff_arr, levels[0], levels[1])
def check_bit_depth(self, bit_depth: int) -> str:
"""
Returns a string indicating the bitdepth to store the images based on value of bit-depth radio buttons
:param bit_depth:
:return: String specifying datatype for numpy array
"""
if bit_depth == 8:
return "uint8"
elif bit_depth == 16:
return "uint16"
elif bit_depth == 32:
return "uint32"
def set_8bit(self):
"""
Sets value of bit_depth variable based on radio button selection
:return: None
"""
LOG.debug("Set 8-bit")
self.bit_depth = 8
def set_16bit(self):
    """
    Radio-button slot: remember 16 as the bit depth used when saving images.
    :return: None
    """
    self.bit_depth = 16
    LOG.debug("Set 16-bit")
def set_32bit(self):
    """
    Radio-button slot: remember 32 as the bit depth used when saving images.
    :return: None
    """
    self.bit_depth = 32
    LOG.debug("Set 32-bit")
tofu-0.12.0/tofu/ez/GUI/login_dialog.py 0000664 0000000 0000000 00000012452 14237137211 0017574 0 ustar 00root root 0000000 0000000 import re
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import (
QDialog,
QLineEdit,
QPushButton,
QLabel,
QGridLayout,
QFileDialog,
QComboBox,
)
from tofu.ez.GUI.message_dialog import error_message
import os
class Login(QDialog):
    """Modal login dialog for BMIT.

    The user either enters a beamline + project ID, in which case the project's
    ``raw`` directory under ``/beamlinedata/BMIT/projects`` is used, or picks a
    writable working directory with the "..." button. On success the caller's
    ``login_parameters_dict`` is updated in place and the dialog is accepted.
    """

    def __init__(self, login_parameters_dict, **kwargs):
        super(Login, self).__init__(**kwargs)
        # Dict supplied by the main GUI; it is updated in place on a successful login
        self.login_parameters_dict = login_parameters_dict
        self.setWindowTitle("USER LOGIN")
        self.setWindowModality(Qt.ApplicationModal)
        self.setAttribute(Qt.WA_DeleteOnClose)
        self.welcome_label = QLabel()
        self.welcome_label.setText("Welcome to BMIT!")
        self.prompt_label_bl = QLabel()
        self.prompt_label_bl.setText("Please select the beamline and project:")
        self.bl_label = QLabel()
        self.bl_label.setText("Beamline:")
        self.bl_entry = QComboBox()
        self.bl_entry.addItems(["BM", "ID"])
        self.proj_label = QLabel()
        self.proj_label.setText("Project:")
        self.proj_entry = QLineEdit()
        self.prompt_label_expdir = QLabel()
        self.prompt_label_expdir.setText("OR select the path to the working directory")
        self.expdir_entry = QLineEdit()
        # Read-only: the working directory can only be chosen via the "..." button
        self.expdir_entry.setReadOnly(True)
        self.expdir_select_button = QPushButton("...")
        self.expdir_select_button.clicked.connect(self.select_expdir_func)
        self.login_button = QPushButton("LOGIN")
        self.login_button.clicked.connect(self.on_login_button_clicked)
        self.set_layout()

    def set_layout(self):
        """Arrange the widgets in a two-column grid."""
        layout = QGridLayout()
        self.welcome_label.setAlignment(Qt.AlignCenter)
        self.prompt_label_bl.setAlignment(Qt.AlignCenter)
        self.prompt_label_expdir.setAlignment(Qt.AlignCenter)
        layout.addWidget(self.welcome_label, 0, 0, 1, 2)
        layout.addWidget(self.prompt_label_bl, 1, 0, 1, 2)
        layout.addWidget(self.bl_label, 2, 0, 1, 1)
        layout.addWidget(self.bl_entry, 2, 1, 1, 1)
        layout.addWidget(self.proj_label, 3, 0, 1, 1)
        layout.addWidget(self.proj_entry, 3, 1, 1, 1)
        layout.addWidget(self.prompt_label_expdir, 4, 0, 1, 2)
        layout.addWidget(self.expdir_entry, 5, 0, 1, 1)
        layout.addWidget(self.expdir_select_button, 5, 1, 1, 1)
        layout.addWidget(self.login_button, 6, 0, 1, 2)
        layout.setSpacing(15)
        layout.setContentsMargins(25, 25, 25, 25)
        self.setLayout(layout)

    def select_expdir_func(self):
        """Let the user pick the working directory (Qt dialog, not the native one)."""
        options = QFileDialog.Options()
        options |= QFileDialog.DontUseNativeDialog
        root_dir = QFileDialog.getExistingDirectory(
            self, "Select working directory", "/data/gui-test", options=options
        )
        if root_dir:
            self.expdir_entry.setText(root_dir)

    def uppercase_project_entry(self):
        """Normalize the project entry to upper case in the widget itself."""
        self.proj_entry.setText(self.proj_entry.text().upper())

    def strip_spaces_from_user_entry(self):
        """Remove spaces from the username entry, if this dialog has one.

        NOTE: this dialog does not create a ``user_entry`` widget (the username
        field is currently disabled), so this is a no-op unless one is added.
        """
        user_entry = getattr(self, "user_entry", None)
        if user_entry is not None:
            user_entry.setText(user_entry.text().replace(" ", ""))

    @property
    def project_name(self):
        return self.proj_entry.text()

    @property
    def user_name(self):
        # No ``user_entry`` widget is created by this dialog; report an empty name then
        user_entry = getattr(self, "user_entry", None)
        return user_entry.text() if user_entry is not None else ""

    @property
    def expdir_name(self):
        return self.expdir_entry.text()

    @property
    def bl_name(self):
        return self.bl_entry.currentText()

    def validate_entries(self):
        """Upper-case the project entry and check it has the CCTNNNNN format.

        :return: True if the project is 2 digits, 1 letter, 5 digits
        """
        self.uppercase_project_entry()
        return bool(re.match(r"^[0-9]{2}[A-Z][0-9]{5}$", self.project_name))

    def validate_dir(self, pdr):
        """Return True if the current user may write into directory *pdr*."""
        return os.access(pdr, os.W_OK)

    def on_login_button_clicked(self):
        """Validate the entered project (preferred) or working directory and accept."""
        if self.project_name != "":
            # Normalize the entry first so the directory that is checked and stored
            # is derived from the same (upper-cased) project ID that gets stored.
            project_valid = self.validate_entries()
            prj_dir_name = os.path.join(
                "/beamlinedata/BMIT/projects/prj" + self.project_name, "raw"
            )
            can_write = self.validate_dir(prj_dir_name)
            if project_valid and can_write:
                self.login_parameters_dict.update({"bl": self.bl_name})
                self.login_parameters_dict.update({"project": self.project_name})
                self.login_parameters_dict.update({"expdir": prj_dir_name})
                self.accept()
            elif not project_valid:
                error_message(
                    "The project should be in format: CCTNNNNN, \n"
                    "where CC is cycle number, "
                    "T is one-letter type, "
                    "and NNNNN is project number"
                )
            else:
                error_message("Cannot write in the project directory")
        elif self.expdir_name != "":
            if self.validate_dir(self.expdir_name):
                self.login_parameters_dict.update({"expdir": self.expdir_name})
                self.accept()
            else:
                error_message("Cannot write in the selected directory")
        else:
            # Previously the click silently did nothing when both fields were empty
            error_message("Please enter a project or select a working directory")
tofu-0.12.0/tofu/ez/GUI/message_dialog.py 0000664 0000000 0000000 00000000666 14237137211 0020114 0 ustar 00root root 0000000 0000000 from PyQt5.QtWidgets import QMessageBox
def message_dialog(window_title, message_text):
    """Show a modal message box with the given title and text; blocks until closed."""
    box = QMessageBox()
    box.setText(message_text)
    box.setWindowTitle(window_title)
    box.exec_()
def error_message(message_text):
    """Show *message_text* in a modal dialog titled "Error"."""
    message_dialog(window_title="Error", message_text=message_text)
def warning_message(message_text):
    """Show *message_text* in a modal dialog titled "Warning"."""
    message_dialog(window_title="Warning", message_text=message_text)
def info_message(message_text):
    """Show *message_text* in a modal dialog titled "Info"."""
    message_dialog(window_title="Info", message_text=message_text)
tofu-0.12.0/tofu/ez/Helpers/ 0000775 0000000 0000000 00000000000 14237137211 0015545 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/Helpers/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0017644 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/ez/Helpers/find_360_overlap.py 0000664 0000000 0000000 00000012420 14237137211 0021156 0 ustar 00root root 0000000 0000000 """
This script takes as input a CT scan that has been collected in "half-acquisition" mode
and produces a series of reconstructed slices, each of which is generated by cropping and
concatenating opposing projections together over a range of "overlap" values (i.e. the pixel column
at which the images are cropped and concatenated).
The objective is to review this series of images to determine the pixel column at which the axis of rotation
is located (much like the axis search function commonly used in reconstruction software).
"""
import os
import numpy as np
import tifffile
from tofu.ez.image_read_write import TiffSequenceReader
import tofu.ez.params as glob_parameters
from tofu.ez.Helpers.stitch_funcs import findCTdirs, stitch_float32_output
def extract_row(dir_name, row):
    """Read one detector row from every image of a TIFF sequence.

    :param dir_name: directory read with TiffSequenceReader
    :param row: index of the row to extract; values outside [0, height) fall back to
        the middle row
    :return: 2D array of shape (num_images, width) with the dtype of the images.
        If the sequence has an odd number of images the last one is dropped, so
        that projections can be split into two equal halves for stitching.
    """
    tsr = TiffSequenceReader(dir_name)
    try:
        first = tsr.read(0)
        n_rows, n_cols = first.shape
        # row == n_rows is already out of range (valid indices are 0..n_rows-1)
        if row < 0 or row >= n_rows:
            row = n_rows // 2
        num_images = tsr.num_images
        if num_images % 2 == 1:
            print(f"odd number of images ({num_images}) in {dir_name}, "
                  f"discarding the last one before stitching pairs")
            num_images -= 1
        # Keep the images' own dtype instead of forcing uint16 (which truncated
        # float data and overflowed 32-bit data)
        rows = np.empty((num_images, n_cols), dtype=first.dtype)
        for i in range(num_images):
            rows[i, :] = tsr.read(i)[row, :]
    finally:
        # Release the reader even if an image fails to load
        tsr.close()
    return rows
def find_overlap(parameters):
    """Reconstruct test slices of half-acquisition scans for a range of overlaps.

    For every CT set found below ``parameters['360overlap_input_dir']`` one detector
    row (``'360overlap_row'``) is flat/dark corrected. For each candidate axis in
    ``range(lower_limit, upper_limit, increment)`` the two halves of the projections
    are stitched into a sinogram saved under ``'360overlap_temp_dir'/<set>``. All
    sinograms are then reconstructed with ``tofu tomo`` into
    ``'360overlap_output_dir'/<set>-sli.tif``.

    CT sets whose flats, darks or projections cannot be read are reported and skipped.
    """
    import subprocess

    params = glob_parameters.params
    tomo_name = params['main_config_tomo_dir_name']
    row = parameters['360overlap_row']
    lower = parameters['360overlap_lower_limit']
    upper = parameters['360overlap_upper_limit']
    increment = parameters['360overlap_increment']

    print("Finding CTDirs...")
    ctdirs, lvl0 = findCTdirs(parameters['360overlap_input_dir'], tomo_name)
    ctdirs.sort()
    print(ctdirs)
    # concatenate images end-to-end and generate a sinogram
    for ctset in ctdirs:
        print("Working on ctset:" + str(ctset))
        index_dir = os.path.basename(os.path.normpath(ctset))
        flats_dir = os.path.join(ctset, params['main_config_flats_dir_name'])
        darks_dir = os.path.join(ctset, params['main_config_darks_dir_name'])
        tomo_dir = os.path.join(ctset, tomo_name)
        flats2_dir = os.path.join(ctset, params['main_config_flats2_dir_name'])
        print(flats_dir)
        # Average over images only (axis 0): flats/darks must remain per-column
        # profiles; a global scalar mean would skip the flat-field correction.
        try:
            row_flat = np.mean(extract_row(flats_dir, row), axis=0)
        except Exception:
            print(f"Problem loading flats in {ctset}")
            continue
        try:
            row_dark = np.mean(extract_row(darks_dir, row), axis=0)
        except Exception:
            print(f"Problem loading darks in {ctset}")
            continue
        try:
            row_tomo = extract_row(tomo_dir, row)
        except Exception:
            print(f"Problem loading projections from {tomo_dir}")
            continue
        row_flat2 = None
        if os.path.exists(flats2_dir):
            try:
                row_flat2 = np.mean(extract_row(flats2_dir, row), axis=0)
            except Exception:
                print(f"Problem loading flats2 in {ctset}")

        num_proj, M = row_tomo.shape
        print('Flat-field correction...')
        flat = row_flat
        if row_flat2 is not None:
            # Interpolate linearly from flats (first projection) to flats2 (last one)
            ramp = np.linspace(0, 1, num_proj)[:, np.newaxis]
            flat = row_flat * (1 - ramp) + row_flat2 * ramp
        # Zero divisions / logs of non-positive values are cleaned by nan_to_num below
        with np.errstate(divide='ignore', invalid='ignore'):
            tomo_ffc = -np.log((row_tomo - row_dark) / np.float32(flat - row_dark))
        del row_tomo, row_dark, row_flat, row_flat2, flat
        np.nan_to_num(tomo_ffc, copy=False, nan=0.0, posinf=0.0, neginf=0.0)

        print('Creating stitched sinograms...')
        sin_tmp_dir = os.path.join(parameters['360overlap_temp_dir'], index_dir)
        os.makedirs(sin_tmp_dir)
        half = num_proj // 2
        for axis in range(lower, upper, increment):
            cro = upper - axis
            sino = stitch_float32_output(tomo_ffc[:half, :], tomo_ffc[half:, ::-1], axis, cro)
            tifffile.imwrite(
                os.path.join(sin_tmp_dir, 'sin-axis-' + str(axis).zfill(4) + '.tif'),
                sino.astype(np.float32))

        # perform reconstructions for each sinogram and save to output folder
        print('Reconstructing slices...')
        reco_axis = M - upper
        out_path = os.path.join(parameters['360overlap_output_dir'], f"{index_dir}-sli.tif")
        # Argument list instead of a shell string: paths with spaces stay intact
        cmd = ['tofu', 'tomo', '--axis', str(reco_axis),
               '--sinograms', sin_tmp_dir, '--output', out_path]
        print(' '.join(cmd))
        try:
            completed = subprocess.run(cmd, check=False)
        except OSError as err:
            print(f"Could not run tofu tomo: {err}")
        else:
            if completed.returncode != 0:
                print(f"tofu tomo exited with code {completed.returncode}")
        print("Finished processing: " + str(index_dir))
    print("********************DONE********************")
    print("Finished processing: " + str(parameters['360overlap_input_dir']))
tofu-0.12.0/tofu/ez/Helpers/mview_main.py 0000664 0000000 0000000 00000006542 14237137211 0020261 0 ustar 00root root 0000000 0000000 #!/bin/python
import os
import numpy
from tofu.util import get_filenames
import re
def check_folders(p, noflats2):
    """Create output directory *p* and its darks/flats/tomo subdirectories if missing.

    A flats2 subdirectory is created as well when *noflats2* compares equal to False.
    """
    wanted = [p, p + "/darks", p + "/flats"]
    if noflats2 == False:  # noqa: E712 - equality (not truthiness) as in the original
        wanted.append(p + "/flats2")
    wanted.append(p + "/tomo")
    for folder in wanted:
        if not os.path.exists(folder):
            os.makedirs(folder)
def rename_Andor(args):
    """Zero-pad the trailing frame numbers of the ``*.tif`` files in ``args.input``.

    All files are renamed to ``<prefix>NNN.tif``. The prefix is taken from the first
    file and the width is the number of digits of the largest frame number, so that
    lexical and numeric order agree. Files that already match are left alone.
    Nothing happens if the directory contains no ``*.tif`` files.
    """
    import shutil

    names = get_filenames(os.path.join(args.input, "*.tif"))
    if not names:
        return
    number_re = re.compile(r".*?([0-9]+)$")
    # [:-4] strips the ".tif" extension before looking for trailing digits
    first_digits = number_re.match(names[0][:-4]).group(1)
    prefix = names[0][:-(len(first_digits) + 4)]
    numbers = [int(number_re.match(name[:-4]).group(1)) for name in names]
    n_dgts = len(str(max(numbers)))
    lin_fmt = prefix + "{:0" + str(n_dgts) + "}.tif"
    for name, num in zip(names, numbers):
        target = lin_fmt.format(num)
        if name != target:
            # Python move instead of a shell "mv": safe for names with spaces
            shutil.move(name, target)
def main_prep(args):
    """Split a single TIFF sequence into flats/darks/tomo(/flats2) directories.

    Expected order of the (sorted) files in ``args.input``, repeated for each of
    ``args.nviews`` views: ``nflats`` flats, ``ndarks`` darks, ``nproj`` projections.
    Unless ``args.noflats2`` is set, one more group of ``nflats`` + ``ndarks`` files
    follows the last view. The flats that follow a view's projections are copied
    into that view's ``flats2``. With several views the output goes to
    ``args.output/zNN``, otherwise to ``args.output``.

    :raises ValueError: if there are no TIFF files or the file count does not match
        the one calculated from the parameters (the input is left untouched then)
    """
    import shutil

    def move_into(src, dst_dir):
        # Same as "mv src dst_dir/": keep the name, overwrite a file of that name
        shutil.move(src, os.path.join(dst_dir, os.path.basename(src)))

    if args.Andor:
        rename_Andor(args)
    frames = get_filenames(os.path.join(args.input, "*.tif"))
    nframes = len(frames)
    if nframes == 0:
        raise ValueError("Check INPUT directory: there are no tif files there")

    files_in_int = args.nflats + args.ndarks + args.nproj
    files_input = files_in_int * args.nviews
    if args.noflats2 == False:  # noqa: E712 - same test as check_folders()
        files_input += args.nflats + args.ndarks
    if files_input != nframes:
        raise ValueError(
            "Sequence length (found {} files) does not match ".format(nframes)
            + "one calculated from input parameters "
            + "(expected {} files)".format(files_input)
        )

    # Replace the first frame with a copy of the second to get rid of the corrupted
    # first file in PCO Edge sequences. Done only after validation so that wrong
    # parameters never destroy input data.
    if nframes > 1:
        shutil.copy(frames[1], frames[0])

    for view in range(args.nviews):
        if args.nviews > 1:
            pout = os.path.join(args.output, "z{:02d}".format(view))
        else:
            pout = args.output
        check_folders(pout, args.noflats2)
        # offset of this view's heading flats
        o = view * files_in_int
        for name in frames[o:o + args.nflats]:
            move_into(name, os.path.join(pout, "flats"))
        o += args.nflats
        for name in frames[o:o + args.ndarks]:
            move_into(name, os.path.join(pout, "darks"))
        o += args.ndarks
        for name in frames[o:o + args.nproj]:
            move_into(name, os.path.join(pout, "tomo"))
        o += args.nproj
        if args.noflats2:
            continue
        # copied, not moved: these are also the heading flats of the next view
        for name in frames[o:o + args.nflats]:
            shutil.copy(name, os.path.join(pout, "flats2"))
    print("========== Done ==========")
tofu-0.12.0/tofu/ez/Helpers/stitch_funcs.py 0000664 0000000 0000000 00000047422 14237137211 0020624 0 ustar 00root root 0000000 0000000 """
Last modified on Apr 1, 2022
@author: sergei gasilov
"""
import glob
import os
import shutil
import numpy as np
import tifffile
from tofu.util import read_image
import multiprocessing as mp
from functools import partial
import re
import warnings
import time
import tofu.ez.params as glob_parameters
def findCTdirs(root: str, tomo_name: str):
    """
    Recursively search *root* for CT directories.

    A directory counts as a CT directory when it has a subdirectory named *tomo_name*
    (the "tomo" entry in the GUI).
    :return: (list of absolute CT directory paths in os.walk order, absolute root path)
    """
    top = os.path.abspath(root)
    found = [dirpath for dirpath, subdirs, _files in os.walk(top) if tomo_name in subdirs]
    return found, top
def prepare(parameters, dir_type: int, ctdir: str):
"""
Validate the stitching settings for one CT directory and resolve its input directory.
In orthogonal mode it first runs ``tofu sinos`` on every vertical step, writing the
results into the temp directory, which then becomes the input directory.
:param parameters: GUI params dict (reads the 'ezstitch_*' keys)
:param dir_type: 1 if the input dir is a CTDir containing Z00-Z0N slices, 2 if it is a parent directory containing CTdirs each containing Z slices
:param ctdir: name of the ctdir - blank string if not using multiple ctdirs
:return: (indir, hmin, hmax, start, stop, step, indtype), where indtype is the scalar
    type of the first pixel of the first (sorted) input image
:raises ValueError: if histogram clipping is requested with hmin == hmax
"""
hmin, hmax = 0.0, 0.0
if parameters['ezstitch_clip_histo']:
if parameters['ezstitch_histo_min'] == parameters['ezstitch_histo_max']:
raise ValueError(' - Define hmin and hmax correctly in order to convert to 8bit')
else:
hmin, hmax = parameters['ezstitch_histo_min'], parameters['ezstitch_histo_max']
# 'ezstitch_start_stop_step' is a comma-separated "start,stop,step" string
start, stop, step = [int(value) for value in parameters['ezstitch_start_stop_step'].split(',')]
if not os.path.exists(parameters['ezstitch_output_dir']):
os.makedirs(parameters['ezstitch_output_dir'])
# Vertical steps (e.g. z00, z01, ...) listed in sorted order
Vsteps = sorted(os.listdir(os.path.join(parameters['ezstitch_input_dir'], ctdir)))
#determine input data type
# NOTE(review): indtype stays unbound (UnboundLocalError at return) if dir_type is neither 1 nor 2
if dir_type == 1:
tmp = os.path.join(parameters['ezstitch_input_dir'], Vsteps[0], parameters['ezstitch_type_image'], '*.tif')
tmp = sorted(glob.glob(tmp))[0]
indtype = type(read_image(tmp)[0][0])
elif dir_type == 2:
tmp = os.path.join(parameters['ezstitch_input_dir'], ctdir, Vsteps[0], parameters['ezstitch_type_image'], '*.tif')
tmp = sorted(glob.glob(tmp))[0]
indtype = type(read_image(tmp)[0][0])
if parameters['ezstitch_stitch_orthogonal']:
# Resample each vertical step with `tofu sinos`, keeping rows start..stop with stride step;
# presumably one output file per selected row (numbered from 0) - TODO confirm
for vstep in Vsteps:
if dir_type == 1:
in_name = os.path.join(parameters['ezstitch_input_dir'], vstep, parameters['ezstitch_type_image'])
out_name = os.path.join(parameters['ezstitch_temp_dir'], vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif')
elif dir_type == 2:
in_name = os.path.join(parameters['ezstitch_input_dir'], ctdir, vstep, parameters['ezstitch_type_image'])
out_name = os.path.join(parameters['ezstitch_temp_dir'], ctdir, vstep, parameters['ezstitch_type_image'], 'sli-%04i.tif')
# NOTE(review): the command goes through the shell with unquoted paths; paths with spaces will break
cmd = 'tofu sinos --projections {} --output {}'.format(in_name, out_name)
cmd += " --y {} --height {} --y-step {}".format(start, stop-start, step)
cmd += " --output-bytes-per-file 0"
os.system(cmd)
# NOTE(review): the purpose of this fixed 10 s pause is not evident from the code; the lost
# indentation of this file also hides whether it runs once per vstep or once after the loop
time.sleep(10)
indir = parameters['ezstitch_temp_dir']
else:
indir = parameters['ezstitch_input_dir']
return indir, hmin, hmax, start, stop, step, indtype
# Pool worker: stitch slice number start + j*step of all vertical steps (Vsteps) into one
# tall image with dx blended overlap rows, and write <output>/<ctdir>/<type>-sti-NNNN.tif.
# N x M is the size of one input image, Nnew = N - dx; ramp holds the dx blending weights.
def exec_sti_mp(start, step, N, Nnew, Vsteps, indir, dx, M, parameters, ramp, hmin, hmax, indtype, ctdir, dir_type, j):
index = start+j*step
# NOTE(review): np.empty is uninitialized; with a single vertical step the loop below never
# runs and this garbage buffer is what gets written
Large = np.empty((Nnew*len(Vsteps)+dx, M), dtype=np.float32)
for i, vstep in enumerate(Vsteps[:-1]):
if dir_type == 1:
tmp = os.path.join(indir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif')
tmp1 = os.path.join(indir, Vsteps[i+1], parameters['ezstitch_type_image'], '*.tif')
elif dir_type == 2:
tmp = os.path.join(indir, ctdir, Vsteps[i], parameters['ezstitch_type_image'], '*.tif')
tmp1 = os.path.join(indir, ctdir, Vsteps[i + 1], parameters['ezstitch_type_image'], '*.tif')
# Orthogonal sections created by prepare() hold only the selected rows, so they are
# addressed by j; otherwise the absolute slice index is used
if parameters['ezstitch_stitch_orthogonal']:
tmp = sorted(glob.glob(tmp))[j]
tmp1 = sorted(glob.glob(tmp1))[j]
else:
tmp = sorted(glob.glob(tmp))[index]
tmp1 = sorted(glob.glob(tmp1))[index]
first = read_image(tmp)
second = read_image(tmp1)
# sample moved downwards
if parameters['ezstitch_sample_moved_down']:
first, second = np.flipud(first), np.flipud(second)
# Scale `second` so that the mean of its top dx rows matches the bottom dx rows of `first`
k = np.mean(first[N - dx:, :]) / np.mean(second[:dx, :])
second = second * k
a, b, c = i*Nnew, (i+1)*Nnew, (i+2)*Nnew
# NOTE(review): for i > 0 the next line overwrites rows a..a+dx (the blend written in the
# previous iteration) and the k-scaled rows of the previous `second` with unscaled rows of
# `first`; only the last pair keeps its blend and scaling - verify this is intended
Large[a:b, :] = first[:N-dx, :]
# Linear cross-fade over the dx overlapping rows
Large[b:b+dx, :] = np.transpose(np.transpose(first[N-dx:, :])*(1 - ramp) +
np.transpose(second[:dx, :]) * ramp)
Large[b+dx:c+dx, :] = second[dx:, :]
pout = os.path.join(parameters['ezstitch_output_dir'],
ctdir,
parameters['ezstitch_type_image']+'-sti-{:>04}.tif'.format(index))
# NOTE(review): tifffile.imsave is a deprecated alias of tifffile.imwrite
if not parameters['ezstitch_clip_histo']:
tifffile.imsave(pout, Large.astype(indtype))
else:
# Map [hmin, hmax] linearly onto [0, 255] and save as 8-bit
Large = 255.0/(hmax-hmin) * (np.clip(Large, hmin, hmax) - hmin)
tifffile.imsave(pout, Large.astype(np.uint8))
def _stitch_ctdir(parameters, dir_type, ctdir):
    """Prepare one CT directory (ctdir == "" for dir_type 1) and stitch its slices in a pool."""
    if parameters['ezstitch_stitch_orthogonal']:
        print(" - Creating orthogonal sections")
    indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir)
    dx = int(parameters['ezstitch_num_overlap_rows'])
    # second: stitch them
    # os.path.join(indir, "") is indir plus a separator, so one code path serves both layouts
    Vsteps = sorted(os.listdir(os.path.join(indir, ctdir)))
    example = glob.glob(os.path.join(indir, ctdir, Vsteps[0],
                                     parameters['ezstitch_type_image'], '*.tif'))[0]
    N, M = read_image(example).shape
    Nnew = N - dx
    ramp = np.linspace(0, 1, dx)
    J = range((stop - start) // step)
    exec_func = partial(exec_sti_mp, start, step, N, Nnew, Vsteps, indir, dx, M,
                        parameters, ramp, hmin, hmax, indtype, ctdir, dir_type)
    print(" - Adjusting and stitching")
    # The context manager shuts the workers down once map() has returned, so no
    # processes leak when several CT directories are handled in a row
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(exec_func, J)
    print("========== Done ==========")


def main_sti_mp(parameters):
    """Stitch vertically overlapping slices for one CT directory or a parent of CT directories.

    Layout 1: ``<input>/<zNN>/<type_image>/*.tif``. Layout 2: the same one level deeper,
    ``<input>/<ctdir>/<zNN>/<type_image>/*.tif``, where results go to ``<output>/<ctdir>``
    and the temp directory is cleared after each CT directory.
    """
    input_dir = parameters['ezstitch_input_dir']
    image_type = parameters['ezstitch_type_image']
    # Check whether indir is CTdir or parent containing CTdirs
    subdirs = sorted(os.listdir(input_dir))
    if not subdirs:
        print("Invalid input directory")
    elif os.path.exists(os.path.join(input_dir, subdirs[0], image_type)):
        print(" - Using CT directory containing slices")
        _stitch_ctdir(parameters, 1, "")
    else:
        first_subdir = os.path.join(input_dir, subdirs[0])
        second_subdirs = sorted(os.listdir(first_subdir)) if os.path.isdir(first_subdir) else []
        if second_subdirs and os.path.exists(os.path.join(first_subdir, second_subdirs[0], image_type)):
            print(" - Using parent directory containing CT directories, each of which contains slices")
            # For each subdirectory do the same thing
            for ctdir in subdirs:
                print("-> Working on " + str(ctdir))
                os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir), exist_ok=True)
                _stitch_ctdir(parameters, 2, ctdir)
                # Clear temp directory
                clear_tmp(parameters)
        else:
            print("Invalid input directory")
    complete_message()
def make_buf(tmp, l, a, b):
    """Allocate an uninitialized buffer for *l* vertically stacked crops [a:b] of image *tmp*.

    :return: (buffer, rows per crop, dtype of the image)
    """
    sample = read_image(tmp)
    crop_rows, width = sample[a:b, :].shape
    buffer = np.empty((crop_rows * l, width), dtype=sample.dtype)
    return buffer, crop_rows, sample.dtype
def exec_conc_mp(start, step, example_im, l, parameters, zfold, indir, ctdir, j):
    """Pool worker: concatenate slice number ``start + j*step`` of all vertical steps.

    Rows ``[ezstitch_first_row:ezstitch_last_row]`` of the matching slice in each
    directory of *zfold* (in the given order, each optionally flipped up/down) are
    stacked vertically and written without dtype conversion to
    ``<output>/<ctdir>/<type>-sti-NNNN.tif``.

    :param example_im: path of an image used to size the output buffer
    :param l: number of vertical steps (len(zfold))
    :param j: task index supplied by Pool.map
    """
    index = start + j * step
    first_row = parameters['ezstitch_first_row']
    last_row = parameters['ezstitch_last_row']
    image_type = parameters['ezstitch_type_image']
    # Orthogonal sections produced by prepare() contain only the selected rows
    # (presumably numbered from 0), so they are addressed by j, otherwise by index
    file_pos = j if parameters['ezstitch_stitch_orthogonal'] else index
    Large, N, dtype = make_buf(example_im, l, first_row, last_row)
    for i, vert in enumerate(zfold):
        pattern = os.path.join(indir, ctdir, vert, image_type, '*.tif')
        fname = sorted(glob.glob(pattern))[file_pos]
        frame = read_image(fname)[first_row:last_row, :]
        if parameters['ezstitch_sample_moved_down']:
            frame = np.flipud(frame)
        Large[i * N:(i + 1) * N, :] = frame
    pout = os.path.join(parameters['ezstitch_output_dir'], ctdir,
                        image_type + '-sti-{:>04}.tif'.format(index))
    # imwrite: imsave is only a deprecated alias in tifffile
    tifffile.imwrite(pout, Large)
def _run_conc_pool(parameters, start, stop, step, example_im, zfold, indir, ctdir):
    """Concatenate every selected slice index with a process pool that is always shut down."""
    exec_func = partial(exec_conc_mp, start, step, example_im, len(zfold),
                        parameters, zfold, indir, ctdir)
    print(" - Concatenating")
    with mp.Pool(processes=mp.cpu_count()) as pool:
        pool.map(exec_func, range((stop - start) // step))
    print("============ Done ============")


def main_conc_mp(parameters):
    """Concatenate (without blending) row ranges of slices from all vertical steps.

    Layout 1: ``<input>/<zNN>/<type_image>/*.tif``. Layout 2: one level deeper,
    ``<input>/<ctdir>/<zNN>/<type_image>/*.tif``, where results go to
    ``<output>/<ctdir>`` and the temp directory is cleared after each CT directory.
    """
    input_dir = parameters['ezstitch_input_dir']
    image_type = parameters['ezstitch_type_image']
    # Check whether indir is CTdir or parent containing CTdirs
    subdirs = sorted(os.listdir(input_dir))
    if subdirs and os.path.exists(os.path.join(input_dir, subdirs[0], image_type)):
        dir_type = 1
        ctdir = ""
        print(" - Using CT directory containing slices")
        if parameters['ezstitch_stitch_orthogonal']:
            print(" - Creating orthogonal sections")
        indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir)
        zfold = sorted(d for d in os.listdir(input_dir) if os.path.isdir(os.path.join(input_dir, d)))
        example = glob.glob(os.path.join(indir, zfold[0], image_type, '*.tif'))[0]
        _run_conc_pool(parameters, start, stop, step, example, zfold, indir, ctdir)
    else:
        first_subdir = os.path.join(input_dir, subdirs[0]) if subdirs else input_dir
        second_subdirs = sorted(os.listdir(first_subdir)) if subdirs and os.path.isdir(first_subdir) else []
        if second_subdirs and os.path.exists(os.path.join(first_subdir, second_subdirs[0], image_type)):
            print(" - Using parent directory containing CT directories, each of which contains slices")
            dir_type = 2
            for ctdir in subdirs:
                print(" == Working on " + str(ctdir) + " ==")
                os.makedirs(os.path.join(parameters['ezstitch_output_dir'], ctdir), exist_ok=True)
                if parameters['ezstitch_stitch_orthogonal']:
                    print(" - Creating orthogonal sections")
                indir, hmin, hmax, start, stop, step, indtype = prepare(parameters, dir_type, ctdir)
                zfold = sorted(os.listdir(os.path.join(indir, ctdir)))
                example = glob.glob(os.path.join(indir, ctdir, zfold[0], image_type, '*.tif'))[0]
                _run_conc_pool(parameters, start, stop, step, example, zfold, indir, ctdir)
                # Clear temp directory
                clear_tmp(parameters)
        else:
            # Same feedback as main_sti_mp instead of silently doing nothing
            print("Invalid input directory")
    complete_message()
############################## HALF ACQ ##############################
def stitch(first, second, axis, crop):
    """Stitch two half-acquisition images that overlap around the rotation axis.

    *second* is expected to be already mirrored left-right. If ``axis <= w/2`` the
    images swap roles. The overlap of width ``dx`` is cross-faded linearly. *second*
    is scaled so that the means of the overlapping regions match, which corrects
    flat-field inconsistency between the two projections.

    :param first: 2D image (height x w)
    :param second: 2D image of the same shape, mirrored
    :param axis: rotation-axis column
    :param crop: number of columns removed from both sides of the stitched image
    :return: stitched image with the dtype of the image placed on the left
    """
    h, w = first.shape
    if axis > w / 2:
        dx = int(2 * (w - axis) + 0.5)
    else:
        dx = int(2 * axis + 0.5)
        first, second = second, first
    target = first.dtype
    result = np.empty((h, 2 * w - dx), dtype=target)
    ramp = np.linspace(0, 1, dx)
    # Without an overlap there is nothing to match (and the means would be NaN)
    k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) if dx > 0 else 1.0
    scaled = second * k
    if np.issubdtype(target, np.integer):
        # Clip to the output integer range so saturated pixels cannot wrap around
        info = np.iinfo(target)
        second = np.clip(scaled, info.min, info.max).astype(target)
    else:
        # Float data is kept as is instead of being truncated to uint16
        second = scaled.astype(target)
    result[:, :w - dx] = first[:, :w - dx]
    result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp
    result[:, w:] = second[:, dx:]
    # Full stitched width is 2*max(axis, w - axis); the old end 2*(w - axis) was only
    # right for axis <= w/2 and truncated the image otherwise
    end = int(2 * max(axis, w - axis) - crop)
    return result[:, int(crop):end]
############################## HALF ACQ ##############################
def stitch_float32_output(first, second, axis, crop):
    """Float variant of ``stitch`` (no integer clipping); output dtype follows the inputs.

    *second* is expected to be already mirrored left-right. If ``axis <= w/2`` the
    images swap roles. *second* is scaled so that the means of the overlapping
    regions match, and the overlap of width ``dx`` is cross-faded linearly.

    :param axis: rotation-axis column
    :param crop: number of columns removed from both sides of the stitched image
    :return: stitched (and cropped) image
    """
    print(f"Stitching two halves with axis {axis}")
    h, w = first.shape
    if axis > w / 2:
        dx = int(2 * (w - axis) + 0.5)
    else:
        dx = int(2 * axis + 0.5)
        first, second = second, first
    result = np.empty((h, 2 * w - dx), dtype=first.dtype)
    ramp = np.linspace(0, 1, dx)
    # Mean values of the overlapping regions must match, which corrects flat-field
    # inconsistency between the two projections (no overlap -> nothing to match)
    k = np.mean(first[:, w - dx:]) / np.mean(second[:, :dx]) if dx > 0 else 1.0
    # Scale the whole of `second`: scaling only the tail left a jump at column w
    second = second * k
    result[:, :w - dx] = first[:, :w - dx]
    result[:, w - dx:w] = first[:, w - dx:] * (1 - ramp) + second[:, :dx] * ramp
    result[:, w:] = second[:, dx:]
    # Full stitched width is 2*max(axis, w - axis); unchanged for axis <= w/2
    end = int(2 * max(axis, w - axis) - crop)
    return result[:, int(crop):end]
def st_mp_idx(offst, ax, crop, in_fmt, out_fmt, idx):
    """Pool worker: stitch image *idx* with its left-right mirrored partner *idx + offst*.

    File names are produced from the index via *in_fmt*/*out_fmt*; the stitched
    image is written to ``out_fmt.format(idx)``.
    """
    left = read_image(in_fmt.format(idx))
    partner_mirrored = np.fliplr(read_image(in_fmt.format(idx + offst)))
    tifffile.imwrite(out_fmt.format(idx), stitch(left, partner_mirrored, ax, crop))
def main_360_mp_depth1(indir, outdir, ax, cro):
    """Stitch half-acquisition images in every subdirectory of *indir* into *outdir*.

    In each subdirectory the sorted ``*.tif`` files are paired as ``i`` and
    ``i + N//2`` (N = number of files). Each pair is stitched at axis *ax*, cropped
    by *cro* columns on each side and written to ``outdir/<subdir>/sti-NNNN.tif``.
    Files are addressed through a zero-padded numeric pattern derived from the
    first file name. Subdirectories with fewer than two images, or whose file names
    carry no trailing number, are skipped with a message.
    """
    os.makedirs(outdir, exist_ok=True)
    subdirs = [dI for dI in os.listdir(indir) if os.path.isdir(os.path.join(indir, dI))]
    number_re = re.compile(r'.*?([0-9]+)$')
    for sdir in subdirs:
        print(f"Stitching images in {sdir}")
        names = sorted(glob.glob(os.path.join(indir, sdir, '*.tif')))
        num_projs = len(names)
        if num_projs < 2:
            warnings.warn("Warning: less than 2 files")
        print(str(num_projs) + ' files in ' + str(sdir))
        if num_projs < 2:
            # Nothing to pair (an empty directory used to raise IndexError below)
            continue
        # extract input file format: <prefix><zero-padded number>.tif
        firstfname = names[0]
        match = number_re.match(firstfname[:-4])
        if match is None:
            print(f"Cannot find a frame number in {firstfname}, skipping {sdir}")
            continue
        firstnum = match.group(1)
        n_dgts = len(firstnum)  # number of significant digits
        idx0 = int(firstnum)
        trnc_len = n_dgts + 4  # digits + ".tif"
        in_fmt = firstfname[:-trnc_len] + '{:0' + str(n_dgts) + '}.tif'
        os.makedirs(os.path.join(outdir, sdir), exist_ok=True)
        out_fmt = os.path.join(outdir, sdir, 'sti-{:>04}.tif')
        offst = num_projs // 2
        idxs = range(idx0, idx0 + offst)
        # Sanity check: the derived pattern must reproduce the sorted file names
        for nmi in idxs:
            if names[nmi - idx0] != in_fmt.format(nmi):
                print('Something wrong with file name format')
        exec_func = partial(st_mp_idx, offst, ax, cro, in_fmt, out_fmt)
        # One pool per subdirectory, shut down when map() has returned
        with mp.Pool(processes=mp.cpu_count()) as pool:
            pool.map(exec_func, idxs)
    print("========== Done ==========")
def main_360_mp_depth2(parameters):
    """Convert all half-acquisition CT sets below '360multi_input_dir' to ordinary scans.

    The CT sets are sorted by path. Each set gets an axis either interpolated linearly
    (rounded) between '360multi_bottom_axis' and '360multi_top_axis' in that order, or
    taken from '360multi_axis_dict' (values in dict order) when '360multi_manual_axis'
    is set. The crop margin is ``max(axis) - axis``, so that all outputs share the same
    width. Results mirror the input tree below '360multi_output_dir'.
    """
    tomo_name = glob_parameters.params['main_config_tomo_dir_name']
    ctdirs, lvl0 = findCTdirs(parameters['360multi_input_dir'], tomo_name)
    # os.walk order is filesystem dependent; axes are assigned by position, so sort
    ctdirs.sort()
    num_sets = len(ctdirs)
    if num_sets < 1:
        print(f"Didn't find any CT dirs in the input. Check directory structure and permissions. \n"
              f"Program expects to see a number of subdirectories in the input each of with \n"
              f"contains at least one directory with CT projections (currently name set to "
              f"{glob_parameters.params['main_config_tomo_dir_name']}. \n"
              f"The tif files in all "
              f" {glob_parameters.params['main_config_tomo_dir_name']}, "
              f" {glob_parameters.params['main_config_flats_dir_name']}, "
              f" {glob_parameters.params['main_config_darks_dir_name']} \n"
              f"subdirectories will be stitched to convert half-acquisition mode scans to ordinary \n"
              f"180-deg parallel-beam scans")
        return
    # findCTdirs returns absolute paths below lvl0; relpath also works when the input
    # path is relative or ends with a separator (string slicing did not)
    ctdirs_rel_paths = [os.path.relpath(ctdir, lvl0) for ctdir in ctdirs]
    print(f"Found the {num_sets} directories in the input with relative paths: {ctdirs_rel_paths}")
    # prepare axis and crop arrays
    dax = np.round(np.linspace(parameters['360multi_bottom_axis'], parameters['360multi_top_axis'], num_sets))
    if parameters['360multi_manual_axis']:
        print(parameters['360multi_axis_dict'])
        dax = np.array(list(parameters['360multi_axis_dict'].values()))
        print(dax)
        if len(dax) != num_sets:
            print(f"Number of manually entered axes ({len(dax)}) does not match "
                  f"the number of CT directories ({num_sets})")
            return
    cra = np.max(dax) - dax
    for i, ctdir in enumerate(ctdirs):
        print("================================================================")
        print(" -> Working On: " + str(ctdir))
        print(f" axis position {dax[i]}, margin to crop {cra[i]} pixels")
        main_360_mp_depth1(ctdir,
                           os.path.join(parameters['360multi_output_dir'], ctdirs_rel_paths[i]),
                           dax[i], cra[i])
def clear_tmp(parameters):
    """Delete every entry inside parameters['ezstitch_temp_dir'], keeping the directory itself."""
    temp_root = parameters['ezstitch_temp_dir']
    for entry in os.listdir(temp_root):
        shutil.rmtree(os.path.join(temp_root, entry))
def check_last_index(axis_list):
    """
    Return the index of the item immediately before the first 'None' string.
    :param axis_list: list of axis entries, where the string 'None' marks unused entries
    :return: index of the last entry before the first 'None' (-1 if the list starts
        with 'None'); the last index if there is no 'None'; 0 for an empty list
    """
    for position, value in enumerate(axis_list):
        if value == 'None':
            return position - 1
    return max(len(axis_list) - 1, 0)
# Print an ASCII-art "FIN" banner to stdout to mark the end of a stitching run.
def complete_message():
print(" __.-/|")
print(" \\`o_O'")
print(" =( )= +-----+")
print(" U| | FIN |")
print(" /\\ /\\ / | +-----+")
print(" ) /^\\) ^\\/ _)\\ |")
print(" ) /^\\/ _) \\ |")
print(" ) _ / / _) \\___|_")
print(" /\\ )/\\/ || | )_)\\___,|))")
print("< > |(,,) )__) |")
print(" || / \\)___)\\")
print(" | \\____( )___) )____")
print(" \\______(_______;;;)__;;;)") tofu-0.12.0/tofu/ez/RR_external.py 0000664 0000000 0000000 00000015041 14237137211 0016743 0 ustar 00root root 0000000 0000000 #!/usr/bin/env python2
"""
Created on Aug 3, 2018
@author: SGasilov
Initially this was a simple median-sorting filter;
it has since been replaced by the non-FFT-based methods proposed by
Nghia T. Vo and published in https://doi.org/10.1364/OE.26.028396
"""
import os
import argparse
from tofu.util import read_image
import numpy as np
from tofu.util import get_filenames
import multiprocessing as mp
from functools import partial
from scipy.ndimage import median_filter
from scipy.ndimage import binary_dilation
import tifffile
def write_tiff(file_name, data):
    """
    The default TIFF writer which uses :py:mod:`tifffile` module.
    Return the written file name.

    Uses ``tifffile.imwrite``; ``imsave`` is only a deprecated alias of it.
    """
    tifffile.imwrite(file_name, data)
    return file_name
def parse_args():
    """Parse the command-line options of the ring-removal script (all optional)."""
    parser = argparse.ArgumentParser()
    option_specs = (
        ("--sinos", str, "Input directory"),
        ("--mws", int, "Window size for small rings (sorting algorithm)"),
        ("--mws2", int, "Window size for large rings"),
        ("--snr", int, "Median window size along columns"),
        ("--sort_only", int, "Only sorting or both"),
    )
    for flag, converter, description in option_specs:
        parser.add_argument(flag, type=converter, help=description)
    return parser.parse_args()
def RR_wide_sort(mws, mws2, snr, odir, fname):
    """Remove large stripes, then small ones (sorting method), from one sinogram file.

    The float32 result is written to *odir* under the input file's base name.
    """
    sinogram = read_image(fname).astype(np.float32)
    without_wide = remove_large_stripe(sinogram, snr, mws2)
    cleaned = remove_stripe_based_sorting(without_wide, mws)
    write_tiff(os.path.join(odir, os.path.basename(fname)), cleaned.astype(np.float32))
def RR_sort(mws, odir, fname):
    """Apply sorting-based stripe removal (window *mws*) to one sinogram file.

    The float32 result is written to *odir* under the input file's base name.
    """
    sinogram = read_image(fname).astype(np.float32)
    cleaned = remove_stripe_based_sorting(sinogram, mws)
    write_tiff(os.path.join(odir, os.path.basename(fname)), cleaned.astype(np.float32))
def remove_stripe_based_sorting(sinogram, size, dim=1):
    # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396
    """
    Remove stripe artifacts in a sinogram using the sorting technique,
    algorithm 3 in Ref. [1]. Angular direction is along the axis 0.

    Every detector column is sorted along the angular axis. The sorted image is then
    median filtered across detector columns (with a square window when ``dim == 2``).
    Finally the filtered values are put back at their original angular positions.
    Sorting and un-sorting use vectorized ``argsort`` / ``take_along_axis`` /
    ``put_along_axis`` instead of a Python loop over columns.

    [1] N. T. Vo, R. C. Atwood, M. Drakopoulos, Opt. Express 26, 28396 (2018),
        https://doi.org/10.1364/OE.26.028396

    Parameters
    ----------
    sinogram : array_like
        2D array. Sinogram image.
    size : int
        Window size of the median filter.
    dim : {1, 2}, optional
        Dimension of the window.

    Returns
    -------
    ndarray
        Corrected sinogram (float64) with the shape of the input.
    """
    # One row per detector column; float64 like the original index/value stacking
    sino_t = np.asarray(sinogram, dtype=np.float64).T
    order = np.argsort(sino_t, axis=1)
    sorted_t = np.take_along_axis(sino_t, order, axis=1)
    window = (size, size) if dim == 2 else (size, 1)
    filtered = median_filter(sorted_t, window)
    # Scatter the filtered values back to the positions they were sorted from
    restored = np.empty_like(filtered)
    np.put_along_axis(restored, order, filtered, axis=1)
    return restored.T
def detect_stripe(list_data, snr):
    # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396
    """
    Locate stripe positions using Algorithm 4 in Ref. [1].

    A line is fitted to the sorted data after dropping roughly a quarter of the samples
    at each end; its rise over the data range is used as the noise level. If the largest
    (smallest) value deviates from the fit by at least *snr* noise levels, all values
    above (below) half that distance are flagged.

    [1] N. T. Vo et al., https://doi.org/10.1364/OE.26.028396

    Parameters
    ----------
    list_data : array_like
        1D array. Normalized data.
    snr : float
        Ratio used to segment stripes from background noise.

    Returns
    -------
    ndarray
        float32 mask with 1.0 at stripe positions and 0.0 elsewhere.
    """
    npoint = len(list_data)
    list_sort = np.sort(list_data)
    listx = np.arange(0, npoint, 1.0)
    # Plain int: np.int16 overflows for inputs longer than ~131k samples
    ndrop = int(0.25 * npoint)
    (slope, intercept) = np.polyfit(listx[ndrop : -ndrop - 1], list_sort[ndrop : -ndrop - 1], 1)
    y_end = intercept + slope * listx[-1]
    noise_level = np.abs(y_end - intercept)
    noise_level = np.clip(noise_level, 1e-6, None)  # avoid division by zero on flat data
    val1 = np.abs(list_sort[-1] - y_end) / noise_level
    val2 = np.abs(intercept - list_sort[0]) / noise_level
    list_mask = np.zeros(npoint, dtype=np.float32)
    if val1 >= snr:
        upper_thresh = y_end + noise_level * snr * 0.5
        list_mask[list_data > upper_thresh] = 1.0
    if val2 >= snr:
        lower_thresh = intercept - noise_level * snr * 0.5
        list_mask[list_data <= lower_thresh] = 1.0
    return list_mask


def remove_large_stripe(sinogram, size, snr=3, drop_ratio=0.1, norm=True):
    # taken from sarepy, Author: Nghia T. Vo https://doi.org/10.1364/OE.26.028396
    """
    Remove large stripes, algorithm 5 in Ref. [1], by: locating stripes,
    normalizing to remove full stripes, and using the sorting technique
    (Ref. [1]) to remove partial stripes. Angular direction is along the
    axis 0.

    [1] N. T. Vo et al., https://doi.org/10.1364/OE.26.028396

    Parameters
    ----------
    sinogram : array_like
        2D array. Sinogram image
    size : int
        Window size of the median filter.
    snr : float, optional
        Ratio used to segment stripes from background noise.
    drop_ratio : float, optional
        Ratio of pixels to be dropped, which is used to reduce the false
        detection of stripes. Clipped to [0, 0.8].
    norm : bool, optional
        Apply normalization if True.

    Returns
    -------
    ndarray
        Corrected sinogram; the input array is not modified.
    """
    sinogram = np.copy(sinogram)  # work on a copy, the caller's array stays untouched
    drop_ratio = np.clip(drop_ratio, 0.0, 0.8)
    nrow = sinogram.shape[0]
    ndrop = int(0.5 * drop_ratio * nrow)
    sino_sort = np.sort(sinogram, axis=0)
    sino_smooth = median_filter(sino_sort, (1, size))
    list1 = np.mean(sino_sort[ndrop : nrow - ndrop], axis=0)
    list2 = np.mean(sino_smooth[ndrop : nrow - ndrop], axis=0)
    list_fact = np.divide(list1, list2, out=np.ones_like(list1), where=list2 != 0)
    list_mask = detect_stripe(list_fact, snr)
    list_mask = np.float32(binary_dilation(list_mask, iterations=1))
    if norm:
        # Broadcasting divides every row by the per-column factor (removes full stripes)
        sinogram = sinogram / list_fact
    # Sorting technique: put the smoothed sorted values back at the original row
    # positions of every column (vectorized instead of a Python loop per column)
    order = np.argsort(sinogram, axis=0)
    sino_cor = np.empty_like(sino_smooth)
    np.put_along_axis(sino_cor, order, sino_smooth, axis=0)
    listx_miss = np.nonzero(list_mask > 0.0)[0]
    sinogram[:, listx_miss] = sino_cor[:, listx_miss]
    return sinogram
def main():
    """Filter every ``*.tif`` sinogram of the ``--sinos`` directory in parallel.

    Results go to ``sinos-filt`` inside the parent directory reported by
    ``os.path.split`` of ``--sinos``. A truthy ``--sort_only`` selects the sorting-only
    filter; otherwise large stripes are removed as well.
    """
    args = parse_args()
    sinos = get_filenames(os.path.join(args.sinos, "*.tif"))
    # create output directory
    wdir = os.path.split(args.sinos)[0]
    odir = os.path.join(wdir, "sinos-filt")
    os.makedirs(odir, exist_ok=True)
    if args.sort_only:
        exec_func = partial(RR_sort, args.mws, odir)
    else:
        exec_func = partial(RR_wide_sort, args.mws, args.mws2, args.snr, odir)
    pool = mp.Pool(processes=mp.cpu_count())
    try:
        pool.map(exec_func, sinos)
    finally:
        # Shut the worker processes down instead of leaving them to garbage collection
        pool.close()
        pool.join()


if __name__ == "__main__":
    main()
tofu-0.12.0/tofu/ez/__init__.py 0000664 0000000 0000000 00000000001 14237137211 0016243 0 ustar 00root root 0000000 0000000
tofu-0.12.0/tofu/ez/ctdir_walker.py 0000664 0000000 0000000 00000016006 14237137211 0017172 0 ustar 00root root 0000000 0000000 """
Created on Apr 5, 2018
@author: gasilos
"""
import os
class WalkCTdirs:
    """
    Walks in the directory structure and creates list of paths to CT folders.

    Determines for every CT folder whether it has flats before/after (type 3 or 4) and
    checks that its folders contain only tiff files.

    fdt_names = flats/darks/tomo/flats2 directory names. Index 2 is the tomo directory
    and index 3 the flats2 directory; index 1 is compared with index 3 to detect a
    "flats2 == flats" configuration.
    """

    def __init__(self, inpath, fdt_names, args, verb=True):
        self.lvl0 = os.path.abspath(inpath)
        self.ctdirs = []
        self.types = []
        self.ctsets = []
        self.typ = []  # type per entry of ctdirs: 0 bad, 3 no flats2, 4 with flats2
        self.total = 0
        self.good = 0
        self.verb = verb
        self._fdt_names = fdt_names
        # Common flats/darks/flats2 directories shared by all CT sets
        self.common_flats = args.main_config_flats_path
        self.common_darks = args.main_config_darks_path
        self.common_flats2 = args.main_config_flats2_path
        self.use_common_flats2 = args.main_config_flats2_checkbox

    def print_tree(self):
        """Print the root directory of the walk."""
        print("We start in {}".format(self.lvl0))

    def findCTdirs(self):
        """
        Walk the directories rooted at the "Input Directory" location (self.lvl0) and
        collect every directory containing a sub-directory named like the "tomo" entry
        (fdt_names[2]). Duplicates are removed; the resulting order is arbitrary.
        """
        tomo_name = self._fdt_names[2]
        for root, dirs, _files in os.walk(self.lvl0):
            if tomo_name in dirs:
                self.ctdirs.append(root)
        self.ctdirs = list(set(self.ctdirs))

    def _subdir_exists(self, ctdir, index):
        """Return True if ``ctdir/fdt_names[index]`` exists."""
        return os.path.exists(os.path.join(ctdir, self._fdt_names[index]))

    def checkCTdirs(self):
        """
        Classify every CT directory and append its type to self.typ (same index as ctdirs).

        Type 3: has flats and darks and no flats2 -- or flats2 has the same name as flats
        Type 4: has flats, darks and flats2
        Type 0: anything else; the directory's base name is printed
        """
        for ctdir in self.ctdirs:
            if self._subdir_exists(ctdir, 1) and self._subdir_exists(ctdir, 0):
                # flats/darks and no flats2 or flats2==flats -> type 3, otherwise type 4
                if (
                    not self._subdir_exists(ctdir, 3)
                    or self._fdt_names[1] == self._fdt_names[3]
                ):
                    self.typ.append(3)
                else:
                    self.typ.append(4)
            else:
                print(os.path.basename(ctdir))
                self.typ.append(0)

    def checkCommonFDT(self):
        """
        Assign type 4 (common flats2 used) or 3 to every CT directory and verify that the
        common flats/darks (and flats2 if used) directories exist.

        :return: True if all required common directories exist, False otherwise
        """
        if self.use_common_flats2 is True:
            self.typ.extend([4] * len(self.ctdirs))
            required = (self.common_flats, self.common_darks, self.common_flats2)
        elif self.use_common_flats2 is False:
            self.typ.extend([3] * len(self.ctdirs))
            required = (self.common_flats, self.common_darks)
        else:
            return False
        return all(os.path.exists(path) for path in required)

    def checkCommonFDTFiles(self):
        """
        Verify that the tomo directories and the common flats/darks(/flats2) directories
        contain only .tif files.

        A tomo directory with invalid content gets type 0 and stops the check.

        :return: True if all checked directories are valid, False at the first invalid one
        """
        for i, ctdir in enumerate(self.ctdirs):
            ctdir_tomo_path = os.path.join(ctdir, self._fdt_names[2])
            if not self._checkTifs(ctdir_tomo_path):
                print("Invalid files found in " + str(ctdir_tomo_path))
                self.typ[i] = 0
                return False
        if not self._checkTifs(self.common_flats):
            print("Invalid files found in " + str(self.common_flats))
            return False
        if not self._checkTifs(self.common_darks):
            print("Invalid files found in " + str(self.common_darks))
            return False
        if self.use_common_flats2 and not self._checkTifs(self.common_flats2):
            print("Invalid files found in " + str(self.common_flats2))
            return False
        return True

    def checkCTfiles(self):
        """
        For every CT directory of type 3 or 4 (self.typ has the same index as ctdirs),
        check that its sub-directories contain only .tif files.

        Type 3 checks fdt_names[0..2], type 4 additionally fdt_names[3]. Directories with
        invalid data, or of any other type, get type 0.
        """
        for i, ctdir in enumerate(self.ctdirs):
            if self.typ[i] == 3:
                needed = (1, 0, 2)
            elif self.typ[i] == 4:
                needed = (1, 0, 2, 3)
            else:
                needed = None
            if needed is None or not all(
                self._checkTifs(os.path.join(ctdir, self._fdt_names[k])) for k in needed
            ):
                self.typ[i] = 0

    def _checkTifs(self, tmpath):
        """
        Check whether every item in directory tmpath is a .tif file.

        :param tmpath: Path to directory
        :return: False (== 0) if a sub-directory or a non-.tif item is found,
            True (== 1) otherwise
        """
        for entry in os.listdir(tmpath):
            # Resolve the entry inside tmpath, not relative to the current working directory
            if os.path.isdir(os.path.join(tmpath, entry)):
                return False
            if entry.split(".")[-1] != "tif":
                return False
        return True

    def SortBadGoodSets(self):
        """
        Pair every CT directory with its type, sort the pairs by path, count total and good
        sets (type 3 or 4), optionally print a summary, and keep only sets of type > 0
        in self.ctsets.
        """
        self.ctsets = sorted(zip(self.ctdirs, self.typ), key=lambda s: s[0])
        self.total = len(self.ctsets)
        self.good = sum(1 for _path, typ in self.ctsets if int(typ) > 2)
        prefix_len = len(self.lvl0)
        if self.verb:
            print("Total folders {}, good folders {}".format(self.total, self.good))
            print("{:>20}\t{}".format("Path to CT set", "Typ: 0 bad, 3 no flats2, 4 with flats2"))
            for path, typ in self.ctsets:
                # Show paths relative to the walk root; the root itself is "."
                relative = path[prefix_len:] or "."
                print("{:>20}\t{}".format(relative, typ))
        # keep paths to directories with good ct data only:
        self.ctsets = [ctset for ctset in self.ctsets if ctset[1] > 0]

    def Getlvl0(self):
        """Return the absolute root path of the walk."""
        return self.lvl0
tofu-0.12.0/tofu/ez/evaluate_sharpness.py 0000664 0000000 0000000 00000031723 14237137211 0020417 0 ustar 00root root 0000000 0000000 import argparse
import glob
import multiprocessing
import os
import time
import numpy as np
from functools import partial
from tofu.util import read_image
from scipy.stats import skew, kurtosis
def sum_abs_gradient(data):
    """Return the sum of the absolute values of the gradient of *data* along all axes."""
    gradients = np.gradient(data)
    return np.abs(gradients).sum()
def mad(data):
    """Return the median absolute deviation of *data* from its median."""
    center = np.median(data)
    deviations = np.abs(data - center)
    return np.median(deviations)
def abs_sum(data):
    """Return the sum of the absolute values of *data*."""
    magnitudes = np.abs(data)
    return magnitudes.sum()
def entropy(data, bins=256):
    """Image entropy.

    Return the Shannon entropy (in bits) of the intensity histogram of *data* with *bins*
    bins; empty histogram bins are ignored.
    """
    hist, _edges = np.histogram(data, bins=bins)
    # np.float was removed in NumPy 1.24, use an explicit float64 probability array
    prob = hist.astype(np.float64)
    prob /= prob.sum()
    valid = prob[prob > 0]
    # np.dot of two 1D arrays already yields the scalar sum of products
    return -np.dot(valid, np.log2(valid))
def inverted(func, *args, **kwargs):
    """Call *func* with the given arguments and return the negated result."""
    value = func(*args, **kwargs)
    return -value
def filter_data(data, fwhm=32.0):
    """Filter low frequencies in 1D *data* (needed when the axis is far away by axis evaluation).

    *fwhm* is the FWHM of the gaussian used to filter out low frequencies in real space. The
    window is then computed as fft(1 - gauss). The mean of *data* is added back to the result.
    """
    # FWHM -> standard deviation in real space -> corresponding sigma in Fourier space
    real_sigma = fwhm / (2 * np.sqrt(2 * np.log(2)))
    freq_sigma = 1.0 / (2 * np.pi * real_sigma)
    frequencies = np.fft.fftfreq(len(data))
    window = 1 - np.exp(-(frequencies ** 2) / (2 * freq_sigma ** 2))
    high_passed = np.fft.ifft(np.fft.fft(data) * window).real
    return high_passed + np.mean(data)
# Metrics evaluated on the flattened image: name -> function(data, **kwargs)
METRICS_1D = {
    "mean": np.mean,
    "std": np.std,
    "skew": skew,
    "kurtosis": kurtosis,
    "mad": mad,
    "asum": abs_sum,
    "min": np.min,
    "max": np.max,
    "entropy": entropy,
}
# Metrics evaluated on the 2D image
METRICS_2D = {"sag": sum_abs_gradient}
# Register a negated variant of every metric under an "m" prefix ("mstd" == -std, ...).
# The comprehension is fully built before update(), so the table is not mutated while read.
for _table in (METRICS_1D, METRICS_2D):
    _table.update({"m" + name: partial(inverted, func) for name, func in _table.items()})
del _table
def evaluate(
    image,
    metrics_1d=None,
    metrics_2d=None,
    global_min=None,
    global_max=None,
    metrics_1d_kwargs=None,
    blur_fwhm=None,
):
    """Evaluate metrics on one *image*, which can either be a file path or an image array.

    *metrics_1d* ({name: func}) work on the flattened image, *metrics_2d* on the 2D image;
    if either is None, all the default ones (METRICS_1D / METRICS_2D) are used.
    *global_min* and *global_max* are the mean extrema of the whole sequence used to cut off
    outlier values: when both are given, only pixels within [global_min, global_max] are
    passed to the 1d metrics. *metrics_1d_kwargs* are additional keyword arguments passed to
    the 1d metric functions, specified in a dictionary {metric_name: kwargs}. *blur_fwhm*,
    if set, is the FWHM of a Gaussian blur applied to the image first.

    Return a dictionary {metric_name: value}.
    """
    if metrics_1d is None:
        metrics_1d = METRICS_1D
    if metrics_2d is None:
        metrics_2d = METRICS_2D
    # isinstance also accepts str subclasses such as numpy.str_
    if isinstance(image, str):
        image = read_image(image)
    if blur_fwhm:
        from scipy.ndimage import gaussian_filter

        # Convert the FWHM to the Gaussian's standard deviation
        image = gaussian_filter(image, blur_fwhm / (2 * np.sqrt(2 * np.log(2))))
    if global_min is None or global_max is None:
        flattened = image.flatten()
    else:
        # Use global cutoff
        flattened = image[np.where((image >= global_min) & (image <= global_max))]
    extra_kwargs = metrics_1d_kwargs or {}
    results = {}
    for metric, func in metrics_1d.items():
        results[metric] = func(flattened, **extra_kwargs.get(metric, {}))
    for metric, func in metrics_2d.items():
        results[metric] = func(image)
    return results
def evaluate_metrics(images, out_prefix, *args, **kwargs):
    """Evaluate many *images*, which are either file paths or images, in parallel.

    *out_prefix* is the metric results file prefix; if not None, every metric is written to
    ``<out_prefix>_<metric>.txt``. *args* and *kwargs* are passed to :func:`evaluate`, except
    for *fwhm* in *kwargs*, which is used to filter low frequencies from the results.

    Return a dictionary {metric_name: array with one value per image}.
    Raise ValueError if *images* is empty.
    """
    images = list(images)
    if not images:
        raise ValueError("No images to evaluate")
    fwhm = kwargs.pop("fwhm", None)
    exec_func = partial(evaluate, *args, **kwargs)
    pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
    try:
        results = pool.map(exec_func, images)
    finally:
        # Release the worker processes (they were never closed before)
        pool.close()
        pool.join()
    merged = {}
    for metric in results[0]:
        merged[metric] = np.array([result[metric] for result in results])
        if fwhm:
            # Filter out low frequencies
            merged[metric] = filter_data(merged[metric], fwhm=fwhm)
        if out_prefix is not None:
            path = out_prefix + "_" + metric + ".txt"
            np.savetxt(path, merged[metric], fmt="%g")
    return merged
def process(
    names,
    num_images_for_stats=0,
    metric_names=None,
    out_prefix=None,
    fwhm=None,
    metrcs_1d_kwargs=None,
    blur_fwhm=None,
):
    """Evaluate the metrics *metric_names* (None means all) on the files *names*.

    *out_prefix* is the output file prefix the metric results are written to. *fwhm* is used
    to filter out low frequencies from the results. *metrcs_1d_kwargs* (keyword name kept
    for compatibility) holds additional keyword arguments for the 1d metrics as a dictionary
    {func_name: kwargs}. If *num_images_for_stats* is non-zero, the mean minimum and maximum
    over the first that many images (-1 means all) serve as a global intensity cutoff.
    """
    global_min = global_max = None
    if num_images_for_stats:
        count = len(names) if num_images_for_stats == -1 else num_images_for_stats
        extrema = evaluate_metrics(
            names[:count],
            None,
            metrics_1d={"min": np.min, "max": np.max},
            fwhm=fwhm,
            blur_fwhm=blur_fwhm,
        )
        global_min = np.mean(extrema["min"])
        global_max = np.mean(extrema["max"])
    selected_1d, selected_2d = make_metrics(metric_names)
    return evaluate_metrics(
        names,
        out_prefix,
        metrics_1d=selected_1d,
        metrics_2d=selected_2d,
        global_min=global_min,
        global_max=global_max,
        fwhm=fwhm,
        metrics_1d_kwargs=metrcs_1d_kwargs,
        blur_fwhm=blur_fwhm,
    )
def main():
    """Command line entry point.

    Evaluate the metrics on all images matching the input pattern and write the results.
    For every metric, print the position of its maximum: an x value for 1D scans and an
    (x, y) tuple for 2D scans.
    """
    args = parse_args()
    names = sorted(glob.glob(args.input))
    if args.dims == 2:
        axis_length = int(np.sqrt(len(names)))
        size_str = "{} x {}".format(axis_length, axis_length)
    else:
        axis_length = len(names)
        size_str = str(axis_length)
    print("Data size: {}".format(size_str))
    kwargs = {"entropy": {"bins": args.entropy_num_bins}}
    # The negated "m" metrics take the same keyword arguments. Iterate over a snapshot of
    # the keys: adding entries while iterating kwargs.keys() raises RuntimeError.
    for key in list(kwargs):
        kwargs["m" + key] = kwargs[key]
    st = time.time()
    results = process(
        names,
        num_images_for_stats=args.num_images_for_stats,
        metric_names=args.metrics,
        fwhm=args.fwhm,
        metrcs_1d_kwargs=kwargs,
        blur_fwhm=args.blur_fwhm,
    )
    if args.verbose:
        print("Duration: {} s".format(time.time() - st))
    if args.dims == 2:
        # One sample per column (x) / row (y) of the 2D scan, used to locate the maximum;
        # x_data/y_data below hold len(data) == axis_length ** 2 samples for writing.
        x_axis = construct_range(args.x_from, args.x_to, axis_length, unit=args.x_unit)
        y_axis = construct_range(args.y_from, args.y_to, axis_length, unit=args.y_unit)
    x_data = y_data = None
    for metric, data in results.items():  # dict.iteritems() does not exist in Python 3
        if x_data is None:
            x_data = construct_range(args.x_from, args.x_to, len(data), unit=args.x_unit)
            y_data = construct_range(args.y_from, args.y_to, len(data), unit=args.y_unit)
        write(
            args.output,
            metric,
            data,
            axis_length,
            x_data=x_data,
            y_data=y_data,
            save_raw=args.save_raw,
            save_txt=args.save_txt,
            save_plot=args.save_plot,
        )
        argmax = np.argmax(data)
        if args.dims == 2:
            row, col = np.unravel_index(argmax, (axis_length, axis_length))
            retval = (x_axis[col].magnitude, y_axis[row].magnitude)
        else:
            retval = x_data[argmax].magnitude
        print(retval)
def write(
    out_dir,
    metric,
    data,
    axis_length,
    x_data=None,
    y_data=None,
    save_raw=False,
    save_txt=False,
    save_plot=False,
):
    """Store the results *data* of *metric* in *out_dir*, which is created if missing.

    1D results (``len(data) == axis_length``) are saved raw as ``<metric>.npy`` and plotted.
    Other results are reshaped to (axis_length, axis_length), saved raw as
    ``<metric>_raw.tif`` and plotted as an image. With *save_txt*, ``<metric>.txt`` gets
    two tab-separated columns x and y. x comes from ``x_data.magnitude`` or, if *x_data* is
    None, from the sample indices.
    """
    out_path = os.path.join(out_dir, metric)
    os.makedirs(out_dir, mode=0o755, exist_ok=True)
    if axis_length == len(data):
        # 1D
        if save_raw:
            np.save(out_path + ".npy", data)
        if save_plot:
            write_1d_plot(out_path, data, metric, x_data=x_data)
    else:
        reshaped = data.reshape(axis_length, axis_length)
        if save_raw:
            # Write the TIFF with tifffile, the TIFF writer used by the other tofu.ez modules
            import tifffile

            tifffile.imwrite(out_path + "_raw" + ".tif", reshaped.astype(np.float32))
        if save_plot:
            write_2d_plot(out_path, reshaped, metric, x_data=x_data, y_data=y_data)
    if save_txt:
        x_values = np.arange(len(data)) if x_data is None else x_data.magnitude
        table = np.array((x_values, data))
        # Convenient to be read by pgfplots
        np.savetxt(out_path + ".txt", table.T, fmt="%g", delimiter="\t", comments="", header="x\ty")
def write_1d_plot(out_path, data, metric, x_data=None):
    """Plot the 1D *data* of *metric* and save the figure as ``<out_path>.tif``.

    If given, *x_data* supplies the abscissa (``magnitude``) and the x label (``units``).
    """
    from matplotlib import pyplot as plt

    plt.figure()
    if x_data is None:
        plt.plot(data)
    else:
        plt.plot(x_data.magnitude, data)
        plt.xlabel(x_data.units)
    plt.title(metric)
    plt.grid()
    plt.savefig(out_path + ".tif")
    plt.close()
def write_2d_plot(out_path, data, metric, x_data=None, y_data=None):
    """Show the 2D *data* of *metric* as a gray image and save it as ``<out_path>.tif``.

    *x_data* / *y_data*, if given, label at most 9 ticks per axis between their first and
    last ``magnitude`` and provide the axis labels (``units``).
    """
    from matplotlib import pyplot as plt, cm

    def _tick_positions(axis_data, num_pixels):
        # Ticks span the pixel edges; labels interpolate between the first and last value
        count = min(num_pixels, 9)
        locations = np.linspace(-0.5, num_pixels - 0.5, count)
        labels = np.linspace(axis_data[0].magnitude, axis_data[-1].magnitude, count)
        return locations, labels

    plt.figure()
    plt.imshow(data, cmap=cm.gray)
    if x_data is not None:
        plt.xticks(*_tick_positions(x_data, data.shape[1]))
        plt.xlabel(x_data.units)
    if y_data is not None:
        plt.yticks(*_tick_positions(y_data, data.shape[0]))
        plt.ylabel(y_data.units)
    plt.title(metric)
    plt.savefig(out_path + ".tif")
    plt.close()
class _RangeWithUnits(object):
    """Minimal values-plus-units container: ``magnitude`` array, ``units`` label, indexable."""

    def __init__(self, magnitude, units):
        self.magnitude = magnitude
        self.units = units

    def __len__(self):
        return len(self.magnitude)

    def __getitem__(self, index):
        # Keep the units so callers can write ``values[i].magnitude``
        return _RangeWithUnits(self.magnitude[index], self.units)


def construct_range(start, stop, num, unit=""):
    """Return *num* evenly spaced values in [start, stop) together with *unit*.

    *start* defaults to 0 and *stop* to *num*, i.e. plain sample indices. The result has
    ``magnitude`` (numpy array) and ``units`` attributes and supports ``len`` and
    indexing. It is a local lightweight container; the previous ``q.Quantity`` referred
    to a name that is not imported in this module.
    """
    start = 0 if start is None else start
    stop = num if stop is None else stop
    region = np.linspace(start, stop, num=num, endpoint=False)
    return _RangeWithUnits(region, unit)
def make_metrics(keys):
    """Build the 1d and 2d metrics dictionaries from the metric names *keys*.

    None selects the complete default tables; names that are not known are ignored.
    """
    if keys is None:
        return METRICS_1D, METRICS_2D

    def _pick(table):
        return {name: table[name] for name in keys if name in table}

    return _pick(METRICS_1D), _pick(METRICS_2D)
def parse_args():
    """Parse and validate the command line of the sharpness evaluation.

    Raise ValueError if only one of x-from/x-to or y-from/y-to is given.
    """
    parser = argparse.ArgumentParser(
        description="Evaluate sharpness metrics based on parameter changes in 3D reconstruction"
    )
    parser.add_argument("input", type=str, help="Input path pattern")
    parser.add_argument(
        "dims", type=int, choices=(1, 2), help="Number of scanned parameters in the data set"
    )
    parser.add_argument("--output", type=str, default=".", help="Output directory")
    parser.add_argument(
        "--metrics",
        type=str,
        nargs="*",
        # dict views cannot be concatenated with + in Python 3; build one list of names
        choices=list(METRICS_1D) + list(METRICS_2D),
        help="Metrics to determine (m prefix means -metric)",
    )
    parser.add_argument("--x-from", type=float, help="X data from")
    parser.add_argument("--x-to", type=float, help="X data to")
    parser.add_argument("--x-unit", type=str, default="", help="X axis units")
    parser.add_argument("--y-from", type=float, help="Y data from")
    parser.add_argument("--y-to", type=float, help="Y data to")
    parser.add_argument("--y-unit", type=str, default="", help="Y axis units")
    parser.add_argument(
        "--num-images-for-stats",
        type=int,
        default=0,
        help=(
            "If not zero, an "
            "image sequence is first read and the mean min and max intensities are "
            "used as a global range of values to work on (-1 means read all images)"
        ),
    )
    parser.add_argument(
        "--fwhm",
        type=float,
        help="FWHM of 1 - Gauss in real space used to filter out low frequencies.",
    )
    parser.add_argument(
        "--entropy-num-bins",
        type=int,
        default=256,
        help="Number of bins to use for histogram calculation by entropy",
    )
    parser.add_argument(
        "--blur-fwhm", type=float, help="FWHM of the Gaussian blur applied to images"
    )
    parser.add_argument("--save-raw", action="store_true", help="Store raw data (1D npy, 2D tiff)?")
    parser.add_argument("--save-txt", action="store_true", help="Store raw data as text files")
    parser.add_argument("--save-plot", action="store_true", help="Store plot data")
    parser.add_argument("--verbose", action="store_true", help="Verbose output")
    args = parser.parse_args()
    if (args.x_from is None) ^ (args.x_to is None):
        raise ValueError("Either both x-from and x-to are set or both are not")
    if (args.y_from is None) ^ (args.y_to is None):
        raise ValueError("Either both y-from and y-to are set or both are not")
    return args


if __name__ == "__main__":
    main()
tofu-0.12.0/tofu/ez/find_axis_cmd_gen.py 0000664 0000000 0000000 00000014131 14237137211 0020135 0 ustar 00root root 0000000 0000000 #!/bin/python
"""
Created on Apr 6, 2018
@author: gasilos
"""
import glob
import os
import numpy as np
from tofu.ez.evaluate_sharpness import process as process_metrics
from tofu.ez.util import enquote
from tofu.util import get_filenames, read_image, determine_shape
import tifffile
class findCOR_cmds(object):
    """
    Generates commands to find the axis of rotation
    """

    def __init__(self, fol):
        # Directory names indexed as [darks, flats, tomo, flats2]
        self._fdt_names = fol

    def make_inpaths(self, lvl0, flats2, args):
        """
        Creates a list of paths to darks/flats/tomo (and flats2) directories.

        :param lvl0: Root of directory containing flats/darks/tomo
        :param flats2: The type of directory: 3 contains flats/darks/tomo, 4 contains
            flats/darks/tomo/flats2 (any value other than 3 adds the flats2 path)
        :param args: Configuration; ``main_config_common_flats_darks`` selects shared
            darks/flats directories (plus flats2 if ``main_config_flats2_checkbox``)
        :return: List of paths [darks, flats, tomo] and flats2 (if used)
        """
        if args.main_config_common_flats_darks:
            # Common flats/darks/flats2 across multiple reconstructions
            indir = [
                args.main_config_darks_path,
                args.main_config_flats_path,
                os.path.join(lvl0, self._fdt_names[2]),
            ]
            if args.main_config_flats2_checkbox:
                indir.append(args.main_config_flats2_path)
            return indir
        # flats/darks/flats2 in the same directory as tomo
        indir = [os.path.join(lvl0, name) for name in self._fdt_names[:3]]
        if flats2 != 3:
            indir.append(os.path.join(lvl0, self._fdt_names[3]))
        return indir

    def find_axis_std(self, ctset, tmpdir, ax_range, p_width, search_row, nviews, args, WH):
        """
        Reconstruct a patch for a range of axis positions with ``tofu reco`` and return the
        position whose slice maximizes the "msag" metric.

        :param ctset: (path, type) of the CT set; type 4 adds ``--flats2``
        :param tmpdir: Directory receiving the temporary ``axis-search`` slices
        :param ax_range: "min,max,step" string of axis positions
        :param p_width: Width of the reconstructed square patch
        :param search_row: Unused; the row comes from ``args.main_cor_search_row_start``
        :param nviews: Number of projections passed to ``--number``
        :param args: Configuration
        :param WH: Image shape; ``WH[0]`` is used as the image height
        :return: The best axis position
        """
        indir = self.make_inpaths(ctset[0], ctset[1], args)
        cmd = "tofu reco --absorptivity --fix-nan-and-inf --overall-angle 180 --axis-angle-x 0"
        # Quote all input paths (not only the projections) so paths with spaces work
        cmd += " --darks {} --flats {} --projections {}".format(
            enquote(indir[0]), enquote(indir[1]), enquote(indir[2])
        )
        cmd += " --number {}".format(nviews)
        if ctset[1] == 4:
            cmd += " --flats2 {}".format(enquote(indir[3]))
        out_pattern = os.path.join(tmpdir, "axis-search/sli")
        cmd += " --output {}".format(enquote(out_pattern))
        cmd += " --x-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1)
        cmd += " --y-region={},{},{}".format(int(-p_width / 2), int(p_width / 2), 1)
        image_height = WH[0]
        ax_range_list = ax_range.split(",")
        range_string = ",".join((ax_range_list[0], ax_range_list[1], ax_range_list[2]))
        cmd += " --region={}".format(range_string)
        res = [float(num) for num in ax_range_list]
        cmd += " --output-bytes-per-file 0"
        cmd += " --z-parameter center-position-x"
        cmd += " --z {}".format(args.main_cor_search_row_start - int(image_height / 2))
        print(cmd)
        os.system(cmd)
        _points, maximum = evaluate_images_simp(out_pattern + "*.tif", "msag")
        # Index of the best slice -> axis position
        return res[0] + res[2] * maximum

    def find_axis_corr(self, ctset, vcrop, y, height, multipage, args):
        """Use correlation to estimate center of rotation for tomography.

        The first and last projections are dark/flat corrected (the last one with flats2 for
        type-4 sets) and optionally cropped to rows [y, y + height). The first one is then
        convolved with the vertically flipped last one via ``fftconvolve``. The axis is
        derived from the column of the convolution maximum.

        :return: Estimated axis position in pixels
        """
        indir = self.make_inpaths(ctset[0], ctset[1], args)
        from scipy.signal import fftconvolve

        def flat_correct(flat, radio):
            """Return log(flat / radio); pixels with radio == 0 or a ratio <= 0 give 0."""
            nonzero = np.where(radio != 0)
            result = np.zeros_like(radio)
            result[nonzero] = flat[nonzero] / radio[nonzero]
            # log(1) = 0
            result[result <= 0] = 1
            return np.log(result)

        # np.float was removed in NumPy 1.24; use np.float64
        if multipage:
            with tifffile.TiffFile(get_filenames(indir[2])[0]) as tif:
                first = tif.pages[0].asarray().astype(np.float64)
            with tifffile.TiffFile(get_filenames(indir[2])[-1]) as tif:
                last = tif.pages[-1].asarray().astype(np.float64)
            with tifffile.TiffFile(get_filenames(indir[0])[-1]) as tif:
                dark = tif.pages[-1].asarray().astype(np.float64)
            with tifffile.TiffFile(get_filenames(indir[1])[0]) as tif:
                flat1 = tif.pages[-1].asarray().astype(np.float64) - dark
        else:
            first = read_image(get_filenames(indir[2])[0]).astype(np.float64)
            last = read_image(get_filenames(indir[2])[-1]).astype(np.float64)
            dark = read_image(get_filenames(indir[0])[-1]).astype(np.float64)
            flat1 = read_image(get_filenames(indir[1])[-1]) - dark
        first = flat_correct(flat1, first - dark)
        if ctset[1] == 4:
            if multipage:
                with tifffile.TiffFile(get_filenames(indir[3])[0]) as tif:
                    flat2 = tif.pages[-1].asarray().astype(np.float64) - dark
            else:
                flat2 = read_image(get_filenames(indir[3])[-1]) - dark
            last = flat_correct(flat2, last - dark)
        else:
            last = flat_correct(flat1, last - dark)
        if vcrop:
            y_region = slice(y, min(y + height, first.shape[0]), 1)
            first = first[y_region, :]
            last = last[y_region, :]
        width = first.shape[1]
        first = first - first.mean()
        last = last - last.mean()
        conv = fftconvolve(first, last[::-1, :], mode="same")
        center = np.unravel_index(conv.argmax(), conv.shape)[1]
        return (width / 2.0 + center) / 2.0

    def find_axis_image_midpoint(self, ctset, multipage, height_width):
        """Return the horizontal midpoint ``height_width[1] // 2`` of the image."""
        return height_width[1] // 2
def evaluate_images_simp(
    input_pattern,
    metric,
    num_images_for_stats=0,
    out_prefix=None,
    fwhm=None,
    blur_fwhm=None,
    verbose=False,
):
    """Evaluate a single *metric* on the files matching *input_pattern* (sorted by name).

    Simplified version of the original evaluate_images function from Tomas's
    optimize_parameters script; *verbose* is accepted but not used.

    :return: (per-image metric values, index of the largest value)
    """
    matching = sorted(glob.glob(input_pattern))
    all_metrics = process_metrics(
        matching,
        num_images_for_stats=num_images_for_stats,
        metric_names=(metric,),
        out_prefix=out_prefix,
        fwhm=fwhm,
        blur_fwhm=blur_fwhm,
    )
    values = all_metrics[metric]
    return values, np.argmax(values)
tofu-0.12.0/tofu/ez/image_read_write.py 0000664 0000000 0000000 00000016334 14237137211 0020013 0 ustar 00root root 0000000 0000000 import os, glob
import numpy as np
import tifffile
from tifffile import imread, imwrite
class InvalidDataSetError(Exception):
    """Raised when data is read from an empty or non-existing data set."""
def validate_files_path(files_path: str, supported_file_types: list) -> bool:
    """
    Check that a directory holds at least one file of a supported type.

    :param files_path: Path to validate
    :param supported_file_types: List of supported extensions, compared with
        ``os.path.splitext`` output (e.g. ".tif")
    :return: True if the directory exists and contains a supported file, else False
    """
    try:
        found = get_valid_files_list(
            files_path=files_path, supported_file_types=supported_file_types
        )
    except InvalidDataSetError:
        return False
    return bool(found)
def get_valid_files_list(files_path: str, supported_file_types: list) -> list:
    """
    Get the list of files of supported type in directory.

    :param files_path: Path to directory with files
    :param supported_file_types: List of supported extensions including the dot, e.g. ".tif"
    :return: Sorted list of full paths to files
    :raises InvalidDataSetError: If *files_path* is not an existing directory
    """
    # isdir rather than exists: a path to a regular file is an invalid data set, not a
    # reason for os.listdir to raise NotADirectoryError
    if not os.path.isdir(files_path):
        raise InvalidDataSetError(f"No such directory: {files_path}")
    return sorted(
        os.path.join(files_path, file_name)
        for file_name in os.listdir(files_path)
        if os.path.splitext(file_name)[1] in supported_file_types
    )
def read_image(image_file_path: str, data_type=np.float32) -> np.ndarray:
    """
    Read an image file and convert it to a numpy array of the requested type.

    :param image_file_path: Full path to image to read
    :param data_type: Data type to store the image
    :return: The image as numpy.ndarray of *data_type*
    """
    raw = imread(image_file_path)
    return raw.astype(dtype=data_type)
def write_image(image: np.ndarray, target_directory: str, target_name: str, data_type=np.float32):
    """
    Write image data, converted to *data_type*, to ``target_directory/target_name``.

    The target directory is created first if it does not exist.

    :param image: Image data
    :param target_directory: Path to directory to write image
    :param target_name: Target image file name
    :param data_type: Data type to write the image
    """
    os.makedirs(target_directory, exist_ok=True)
    destination = os.path.join(target_directory, target_name)
    imwrite(destination, data=image.astype(dtype=data_type))
def write_all_images(tiff_arr: np.ndarray, target_directory: str, data_type=np.float32):
    """
    Write every image of *tiff_arr* (iterated over its first axis) to its own file.

    Files are named ``Image_<n>.tif`` with a 1-based index, zero-padded to one digit more
    than the number of digits of the image count.

    :param tiff_arr: Array containing images
    :param target_directory: Path to directory to write images
    :param data_type: Data type to write the images
    """
    print("Writing Images to Directory")
    pad_width = len(str(tiff_arr.shape[0])) + 1
    for number, frame in enumerate(tiff_arr, start=1):
        file_name = "Image_" + str(number).zfill(pad_width) + ".tif"
        write_image(frame, target_directory, file_name, data_type)
    print("Finished Writing Images to Directory")
def read_all_images(
    image_files_path: str, supported_image_types: list, data_type=np.float32
) -> np.ndarray:
    """
    Read every image of a supported type from a directory into one array.

    :param image_files_path: Path to directory with images
    :param supported_image_types: List of supported extensions
    :param data_type: Data type to store the images
    :return: numpy.ndarray of *data_type* whose first index is the image index
    :raises InvalidDataSetError: If the directory is missing or has no supported files
    """
    paths = get_valid_files_list(
        files_path=image_files_path, supported_file_types=supported_image_types
    )
    if not paths:
        raise InvalidDataSetError(
            f"Directory {image_files_path} "
            f"does not contain files of supported types {supported_image_types}"
        )
    stack = imread(paths).astype(dtype=data_type)
    return np.array(stack)
"""Image readers for convenient work with multi-page image sequences."""
""" TAKEN STRAIT FROM ufo-kit/Concert (python 2 version) with permission of Tomas"""
class FileSequenceReader(object):
    """Image sequence reader optimized for reading consecutive images. One multi-page image file is
    not closed after an image is read so that it does not have to be re-opened for reading the next
    image. The :func:`.close` function must be called explicitly in order to close the last opened
    image.

    Subclasses implement ``_open_real``, ``_close_real``, ``_get_num_images_in_file_real`` and
    ``_read_real``.
    """

    def __init__(self, file_prefix, ext=''):
        # A directory means "all files with extension *ext* inside it"
        if os.path.isdir(file_prefix):
            file_prefix = os.path.join(file_prefix, '*' + ext)
        self._filenames = sorted(glob.glob(file_prefix))
        if not self._filenames:
            raise SequenceReaderError("No files matching `{}' found".format(file_prefix))
        self._lengths = {}  # file name -> number of images, filled lazily
        self._file = None
        self._filename = None

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.close()

    @property
    def num_images(self):
        """Total number of images in all files."""
        return sum(self._get_num_images_in_file(filename) for filename in self._filenames)

    def read(self, index):
        """Return the image at the global *index*; negative values count from the end.

        :raises SequenceReaderError: If *index* is out of range
        """
        if index < 0:
            # Enables negative indexing
            index += self.num_images
            if index < 0:
                # Previously this silently wrapped around into the last file
                raise SequenceReaderError('image index out of range')
        # Skip whole files until *index* falls inside one of them
        for filename in self._filenames:
            count = self._get_num_images_in_file(filename)
            if index < count:
                self._open(filename)
                return self._read_real(index)
            index -= count
        raise SequenceReaderError('image index greater than sequence length')

    def _open(self, filename):
        """Make *filename* the open file, closing a previously opened different file."""
        if self._filename != filename:
            if self._filename:
                self.close()
            self._file = self._open_real(filename)
            self._filename = filename

    def close(self):
        """Close the currently open file, if any."""
        if self._filename:
            self._close_real()
            self._file = None
            self._filename = None

    def _get_num_images_in_file(self, filename):
        """Return (and cache) the number of images in *filename*; opens the file if needed."""
        if filename not in self._lengths:
            self._open(filename)
            self._lengths[filename] = self._get_num_images_in_file_real()
        return self._lengths[filename]

    def _open_real(self, filename):
        """Returns an open file."""
        raise NotImplementedError

    def _close_real(self):
        """Close ``self._file``."""
        raise NotImplementedError

    def _get_num_images_in_file_real(self):
        """Return the number of images in ``self._file``."""
        raise NotImplementedError

    def _read_real(self, index):
        """Return image *index* of ``self._file``."""
        raise NotImplementedError
class TiffSequenceReader(FileSequenceReader):
def __init__(self, file_prefix, ext='.tif'):
super(TiffSequenceReader, self).__init__(file_prefix, ext=ext)
def _open_real(self, filename):
import tifffile
return tifffile.TiffFile(filename)
def _close_real(self):
self._file.close()
def _get_num_images_in_file_real(self):
return len(self._file.pages)
def _read_real(self, index):
return self._file.pages[index].asarray()
class SequenceReaderError(Exception):
pass tofu-0.12.0/tofu/ez/main.py 0000664 0000000 0000000 00000040361 14237137211 0015445 0 ustar 00root root 0000000 0000000 """
Created on Apr 5, 2018
@author: sergei gasilov
"""
import logging
import os
from tofu.util import get_filenames, read_image
import warnings
warnings.filterwarnings("ignore")
import time
#from shutil import rmtree
from tofu.ez.ctdir_walker import WalkCTdirs
from tofu.ez.tofu_cmd_gen import tofu_cmds
from tofu.ez.ufo_cmd_gen import ufo_cmds
from tofu.ez.find_axis_cmd_gen import findCOR_cmds
from tofu.ez.util import *
# from tofu.util import get_filenames
LOG = logging.getLogger(__name__)
def get_CTdirs_list(inpath, fdt_names, args):
"""
Determines whether directories containing CT data are valid.
Returns list of subdirectories with valid CT data
:param inpath: Path to the CT directory containing subdirectories with flats/darks/tomo (and flats2 if used)
:param fdt_names: Names of the directories which store flats/darks/tomo (and flats2 if used)
:param args: Arguments from the GUI
:return: W.ctsets: List of "good" CTSets and W.lvl0: Path to root of CT sets
"""
# Constructor call to create WalkCTDirs object
W = WalkCTdirs(inpath, fdt_names, args)
# Find any directories containing "tomo" directory
W.findCTdirs()
# If "Use common flats/darks across multiple experiments" is enabled
if args.main_config_common_flats_darks:
logging.debug("Use common darks/flats")
logging.debug("Path to darks: " + str(args.main_config_darks_path))
logging.debug("Path to flats: " + str(args.main_config_flats_path))
logging.debug("Path to flats2: " + str(args.main_config_flats2_path))
logging.debug("Use flats2: " + str(args.main_config_flats2_checkbox))
# Determine whether paths to common flats/darks/flats2 exist
if not W.checkCommonFDT():
print("Invalid path to common flats/darks")
return W.ctsets, W.lvl0
else:
LOG.debug("Paths to common flats/darks exist")
# Check whether directories contain only .tif files
if not W.checkCommonFDTFiles():
return W.ctsets, W.lvl0
else:
# Sort good bad sets
W.SortBadGoodSets()
return W.ctsets, W.lvl0
# If "Use common flats/darks across multiple experiments" is not enabled
else:
LOG.debug("Use flats/darks in same directory as tomo")
# Check if common flats/darks/flats2 are type 3 or 4
W.checkCTdirs()
# Need to check if common flats/darks contain only .tif files
W.checkCTfiles()
W.SortBadGoodSets()
return W.ctsets, W.lvl0
def frmt_ufo_cmds(cmds, ctset, out_pattern, ax, args, Tofu, Ufo, FindCOR, nviews, WH):
"""formats list of processing commands for a CT set"""
# two helper variables to mark that PR/FFC has been done at some step
swiFFC = True # FFC is always required required
swiPR = args.main_pr_phase_retrieval # PR is an optional operation
####### PREPROCESSING #########
flat_file_for_mask = os.path.join(args.main_config_temp_dir, 'flat.tif')
if args.main_filters_remove_spots:
if not args.main_config_common_flats_darks:
flatdir = os.path.join(ctset[0], Tofu._fdt_names[1])
elif args.main_config_common_flats_darks:
flatdir = args.main_config_flats_path
cmd = make_copy_of_flat(flatdir, flat_file_for_mask, args.main_config_dry_run)
cmds.append(cmd)
if args.main_config_preprocess:
cmds.append('echo " - Applying filter(s) to images "')
cmds_prepro = Ufo.get_pre_cmd(ctset, args.main_config_preprocess_command,
args.main_config_temp_dir,
args.main_config_dry_run, args)
cmds.extend(cmds_prepro)
# reset location of input data
ctset = (args.main_config_temp_dir, ctset[1])
###################################################
if args.main_filters_remove_spots: # generate commands to remove sci. spots from projections
cmds.append('echo " - Flat-correcting and removing large spots"')
cmds_inpaint = Ufo.get_inp_cmd(ctset, args.main_config_temp_dir, args, WH[0], nviews, flat_file_for_mask)
# reset location of input data
ctset = (args.main_config_temp_dir, ctset[1])
cmds.extend(cmds_inpaint)
swiFFC = False # no need to do FFC anymore
######## PHASE-RETRIEVAL #######
# Do PR separately if sinograms must be generate or if vertical ROI is defined
if args.main_pr_phase_retrieval and args.main_filters_ring_removal: # or (args.main_pr_phase_retrieval and args.main_region_select_rows):
if swiFFC: # we still need need flat correction #Inpaint No
cmds.append('echo " - Phase retrieval with flat-correction"')
if args.advanced_ffc_sinFFC:
cmds.append(Tofu.get_pr_sinFFC_cmd(ctset, args, nviews, WH[0]))
cmds.append(Tofu.get_pr_tofu_cmd_sinFFC(ctset, args, nviews, WH))
elif not args.advanced_ffc_sinFFC:
cmds.append(Tofu.get_pr_tofu_cmd(ctset, args, nviews, WH[0]))
else: # Inpaint Yes
cmds.append('echo " - Phase retrieval from flat-corrected projections"')
cmds.extend(Ufo.get_pr_ufo_cmd(args, nviews, WH))
swiPR = False # no need to do PR anymore
swiFFC = False # no need to do FFC anymore
# if args.PR and args.vcrop: # have to reset location of input data
# ctset = (args.tmpdir, ctset[1])
################# RING REMOVAL #######################
if args.main_filters_ring_removal:
# Generate sinograms first
if swiFFC: # we still need to do flat-field correction
if args.advanced_ffc_sinFFC:
# Create flat corrected images using sinFFC
cmds.append(Tofu.get_sinFFC_cmd(ctset, args, nviews, WH[0]))
# Feed the flat corrected images to sino gram generation
cmds.append(Tofu.get_sinos_noffc_cmd(ctset[0], args.main_config_temp_dir, args, nviews, WH))
elif not args.advanced_ffc_sinFFC:
cmds.append('echo " - Make sinograms with flat-correction"')
cmds.append(Tofu.get_sinos_ffc_cmd(ctset, args.main_config_temp_dir, args, nviews, WH))
else: # we do not need flat-field correction
cmds.append('echo " - Make sinograms without flat-correction"')
cmds.append(Tofu.get_sinos_noffc_cmd(ctset[0], args.main_config_temp_dir, args, nviews, WH))
swiFFC = False
# Filter sinograms
if args.main_filters_ring_removal_ufo_lpf:
if args.main_filters_ring_removal_ufo_lpf_1d_or_2d:
cmds.append('echo " - Ring removal - ufo 1d stripes filter"')
cmds.append(Ufo.get_filter1d_sinos_cmd(args.main_config_temp_dir,
args.main_filters_ring_removal_ufo_lpf_sigma_horizontal, nviews))
else:
cmds.append('echo " - Ring removal - ufo 2d stripes filter"')
cmds.append(Ufo.get_filter2d_sinos_cmd(args.main_config_temp_dir, \
args.main_filters_ring_removal_ufo_lpf_sigma_horizontal,
args.main_filters_ring_removal_ufo_lpf_sigma_vertical,
nviews, WH[1]))
else:
cmds.append('echo " - Ring removal - sarepy filter(s)"')
# note - calling an external program, not an ufo-kit script
tmp = os.path.dirname(os.path.abspath(__file__))
path_to_filt = os.path.join(tmp, "RR_external.py")
if os.path.isfile(path_to_filt):
tmp = os.path.join(args.main_config_temp_dir, "sinos")
cmdtmp = 'python {} --sinos {} --mws {} --mws2 {} --snr {} --sort_only {}' \
.format(path_to_filt, tmp,
args.main_filters_ring_removal_sarepy_window_size,
args.main_filters_ring_removal_sarepy_window,
args.main_filters_ring_removal_sarepy_SNR,
int(not args.main_filters_ring_removal_sarepy_wide))
cmds.append(cmdtmp)
else:
cmds.append('echo "Omitting RR because file with filter does not exist"')
if not args.main_config_keep_temp:
cmds.append("rm -rf {}".format(os.path.join(args.main_config_temp_dir, "sinos")))
# Convert filtered sinograms back to projections
cmds.append('echo " - Generating proj from filtered sinograms"')
cmds.append(Tofu.get_sinos2proj_cmd(args, WH[0]))
# reset location of input data
ctset = (args.main_config_temp_dir, ctset[1])
# Finally - call to tofu reco
cmds.append('echo " - CT with axis {}; ffc:{}, PR:{}"'.format(ax, swiFFC, swiPR))
if args.advanced_ffc_sinFFC and swiFFC:
cmds.append(Tofu.get_sinFFC_cmd(ctset, args, nviews, WH[0]))
cmds.append(
Tofu.get_reco_cmd_sinFFC(ctset, out_pattern, ax, args, nviews, WH, swiFFC, swiPR)
)
else: # If not using sinFFC
cmds.append(Tofu.get_reco_cmd(ctset, out_pattern, ax, args, nviews, WH, swiFFC, swiPR))
return nviews, WH
def fmt_nlmdn_ufo_cmd(inpath: str, outpath: str, args):
"""
:param inp: Path to input directory before NLMDN applied
:param out: Path to output directory after NLMDN applied
:param args: List of args
:return:
"""
cmd = 'ufo-launch read path={}'.format(inpath)
cmd += ' ! non-local-means patch-radius={}'.format(args.advanced_nlmdn_patch_radius)
cmd += ' search-radius={}'.format(args.advanced_nlmdn_sim_search_radius)
cmd += ' h={}'.format(args.advanced_nlmdn_smoothing_control)
cmd += ' sigma={}'.format(args.advanced_nlmdn_noise_std)
cmd += ' window={}'.format(args.advanced_nlmdn_window)
cmd += ' fast={}'.format(args.advanced_nlmdn_fast)
cmd += ' estimate-sigma={}'.format(args.advanced_nlmdn_estimate_sigma)
cmd += ' ! write filename={}'.format(enquote(outpath))
if not args.advanced_nlmdn_save_bigtiff:
cmd += " bytes-per-file=0 tiff-bigtiff=False"
return cmd
def execute_reconstruction(args, fdt_names):
# array with the list of commands
cmds = []
# clean temporary directory or create if it doesn't exist
if not os.path.exists(args.main_config_temp_dir):
os.makedirs(args.main_config_temp_dir)
# else:
# clean_tmp_dirs(args.main_config_temp_dir, fdt_names)
if args.main_region_clip_histogram:
if args.main_region_histogram_min > args.main_region_histogram_max:
raise ValueError('hmin must be smaller than hmax to convert to 8bit without contrast inversion')
# get list of all good CT directories to be reconstructed
print('*********** Analyzing input directory ************')
W, lvl0 = get_CTdirs_list(args.main_config_input_dir, fdt_names, args)
# W is an array of tuples (path, type)
# get list of already reconstructed sets
recd_sets = findSlicesDirs(args.main_config_output_dir)
# initialize command generators
FindCOR = findCOR_cmds(fdt_names)
Tofu = tofu_cmds(fdt_names)
Ufo = ufo_cmds(fdt_names)
# populate list of reconstruction commands
print("*********** AXIS INFO ************")
for i, ctset in enumerate(W):
# ctset is a tuple containing a path and a type (3 or 4)
if not already_recd(ctset[0], lvl0, recd_sets):
# determine initial number of projections and their shape
path2proj = os.path.join(ctset[0], fdt_names[2])
nviews, WH, multipage = get_dims(path2proj)
# If args.main_cor_axis_search_method == 4 then bypass axis search and use image midpoint
if args.main_cor_axis_search_method != 4:
if (args.main_region_select_rows and bad_vert_ROI(multipage, path2proj,
args.main_region_first_row, args.main_region_number_rows)):
print('{}\t{}'.format('CTset:', ctset[0]))
print('{:>30}\t{}'.format('Axis:', 'na'))
print('Vertical ROI does not contain any rows.')
print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, WH))
continue
# Find axis of rotation using auto: correlate first/last projections
if args.main_cor_axis_search_method == 1:
ax = FindCOR.find_axis_corr(ctset,
args.main_region_select_rows,
args.main_region_first_row,
args.main_region_number_rows, multipage, args)
# Find axis of rotation using auto: minimize STD of a slice
elif args.main_cor_axis_search_method == 2:
cmds.append("echo \"Cleaning axis-search in tmp directory\"")
os.system('rm -rf {}'.format(os.path.join(args.main_config_temp_dir, 'axis-search')))
ax = FindCOR.find_axis_std(ctset,
args.main_config_temp_dir,
args.main_cor_axis_search_interval,
args.main_cor_recon_patch_size,
args.main_cor_search_row_start,
nviews, args, WH)
else:
ax = args.main_cor_axis_column + i * args.main_cor_axis_increment_step
# If args.main_cor_axis_search_method == 4 then bypass axis search and use image midpoint
elif args.main_cor_axis_search_method == 4:
ax = FindCOR.find_axis_image_midpoint(ctset, multipage, WH)
print("Bypassing axis search and using image midpoint: {}".format(ax))
setid = ctset[0][len(lvl0) + 1:]
out_pattern = os.path.join(args.main_config_output_dir, setid, 'sli/sli')
cmds.append('echo ">>>>> PROCESSING {}"'.format(setid))
# rm files in temporary directory first of all to
# format paths correctly and to avoid problems
# when reconstructing ct sets with variable number of rows or projections
cmds.append('echo "Cleaning temporary directory"'.format(setid))
clean_tmp_dirs(args.main_config_temp_dir, fdt_names)
# call function which formats commands for this data set
nviews, WH = frmt_ufo_cmds(cmds, ctset, out_pattern, \
ax, args, Tofu, Ufo, FindCOR, nviews, WH)
save_params(args, setid, ax, nviews, WH)
print('{}\t{}'.format('CTset:', ctset[0]))
print('{:>30}\t{}'.format('Axis:', ax))
print("{:>30}\t{}, dimensions: {}".format("Number of projections:", nviews, WH))
# tmp = "Number of projections: {}, dimensions: {}".format(nviews, WH)
# cmds.append("echo \"{}\"".format(tmp))
if args.advanced_nlmdn_apply_after_reco:
logging.debug("Using Non-Local Means Denoising")
nlmdn_input = out_pattern
head, tail = os.path.split(out_pattern)
slidir = os.path.dirname(head)
nlmdn_output = os.path.join(slidir+"-nlmdn", "sli-nlmdn-%04i.tif")
cmds.append(fmt_nlmdn_ufo_cmd(slidir, nlmdn_output, args))
else:
print("{} has been already reconstructed".format(ctset[0]))
# execute commands = start reconstruction
start = time.time()
print("*********** PROCESSING ************")
for cmd in cmds:
if not args.main_config_dry_run:
os.system(cmd)
else:
print(cmd)
if not args.main_config_keep_temp:
clean_tmp_dirs(args.main_config_temp_dir, fdt_names)
print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")
print("*** Done. Total processing time {} sec.".format(int(time.time() - start)))
print("*** Waiting for the next job...........")
# cmnds, axes = get_ufo_cmnds(W, tmpdir, recodir, fol, axes = None, dryrun = False)
def already_recd(ctset, indir, recd_sets):
x = False
if ctset[len(indir) + 1 :] in recd_sets:
x = True
return x
def findSlicesDirs(lvl0):
recd_sets = []
for root, dirs, files in os.walk(lvl0):
for name in dirs:
if name == "sli":
recd_sets.append(root[len(lvl0) + 1 :])
return recd_sets
tofu-0.12.0/tofu/ez/main_nlm.py 0000664 0000000 0000000 00000002337 14237137211 0016314 0 ustar 00root root 0000000 0000000 """
Created on Dec 1, 2020
@author: sergei gasilov
"""
import os
import warnings
from tofu.ez.ctdir_walker import WalkCTdirs
from tofu.ez.util import *
from tofu.util import get_filenames
from tofu.util import get_filenames, read_image
from tofu.ez.util import enquote
warnings.filterwarnings("ignore")
def fmt_ufo_cmd(inp, out, args):
cmd = "ufo-launch read path={}".format(inp)
cmd += " ! non-local-means patch-radius={}".format(args.patch_r)
cmd += " search-radius={}".format(args.search_r)
cmd += " h={}".format(args.h)
cmd += " sigma={}".format(args.sig)
cmd += " window={}".format(args.w)
cmd += " fast={}".format(args.fast)
cmd += " estimate-sigma={}".format(args.autosig)
cmd += " ! write filename={}".format(enquote(out))
if not args.bigtif:
cmd += " bytes-per-file=0 tiff-bigtiff=False"
return cmd
def main_tk(args):
if args.input_is_file:
out_pattern = args.outdir
else:
if not os.path.exists(args.outdir):
os.makedirs(args.outdir)
out_pattern = os.path.join(args.outdir, "im-nlmfilt-%05i.tif")
cmd = fmt_ufo_cmd(args.indir, out_pattern, args)
if args.dryrun:
print(cmd)
else:
os.system(cmd)
return 0
tofu-0.12.0/tofu/ez/params.py 0000664 0000000 0000000 00000000562 14237137211 0016003 0 ustar 00root root 0000000 0000000 # This file is used to share params as a global variable
import yaml
#TODO Make good structure to store parameters
# similar to tofu? and
# use tofu's structure for existing reco params
params = {}
def save_parameters(params, file_path):
file_out = open(file_path, 'w')
yaml.dump(params, file_out)
print("Parameters file saved at: " + str(file_path))
tofu-0.12.0/tofu/ez/tofu_cmd_gen.py 0000664 0000000 0000000 00000053741 14237137211 0017160 0 ustar 00root root 0000000 0000000 #!/bin/python
"""
Created on Apr 6, 2018
@author: gasilos
"""
import os
import numpy as np
from tofu.ez.ufo_cmd_gen import fmt_in_out_path
class tofu_cmds(object):
"""
Generates partially formatted ufo-launch and tofu commands
Parameters are included in the string; pathnames must be added
"""
def __init__(self, fol):
self._fdt_names = fol
def make_inpaths(self, lvl0, flats2, args):
"""
Creates a list of paths to flats/darks/tomo directories
:param lvl0: Root of directory containing flats/darks/tomo
:param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2
:return: List of paths to the directories containing darks/flats/tomo and flats2 (if used)
"""
indir = []
# If using flats/darks/flats2 in same dir as tomo
if not args.main_config_common_flats_darks:
for i in self._fdt_names[:3]:
indir.append(os.path.join(lvl0, i))
if flats2 - 3:
indir.append(os.path.join(lvl0, self._fdt_names[3]))
return indir
# If using common flats/darks/flats2 across multiple reconstructions
elif args.main_config_common_flats_darks:
indir.append(args.main_config_darks_path)
indir.append(args.main_config_flats_path)
indir.append(os.path.join(lvl0, self._fdt_names[2]))
if args.main_config_flats2_checkbox:
indir.append(args.main_config_flats2_path)
return indir
def check_lamino(self, cmd, args):
cmd += 'tofu reco'
if not args.advanced_advtofu_lamino_angle == '':
cmd += ' --axis-angle-x {}'.format(args.advanced_advtofu_lamino_angle)
if not args.advanced_adv_tofu_z_axis_rotation == '':
cmd += ' --overall-angle {}'.format(args.advanced_adv_tofu_z_axis_rotation)
if not args.advanced_advtofu_center_position_z == '':
cmd += ' --center-position-z {}'.format(args.advanced_advtofu_center_position_z)
if not args.advanced_advtofu_y_axis_rotation == '':
cmd += ' --axis-angle-y {}'.format(args.advanced_advtofu_y_axis_rotation)
return cmd
def check_8bit(self, cmd, gray256, bit, hmin, hmax):
if gray256:
cmd += " --output-bitdepth {}".format(bit)
# cmd += " --output-minimum \" {}\" --output-maximum \" {}\""\
# .format(hmin, hmax)
cmd += ' --output-minimum " {}" --output-maximum " {}"'.format(hmin, hmax)
return cmd
def check_vcrop(self, cmd, vcrop, y, yheight, ystep, ori_height):
if vcrop:
cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep)
else:
cmd += " --height {}".format(ori_height)
return cmd
def check_bigtif(self, cmd, swi):
if not swi:
cmd += " --output-bytes-per-file 0"
return cmd
def get_1step_ct_cmd(self, ctset, out_pattern, ax, args, nviews, WH):
# direct CT reconstruction from input dir to output dir;
# or CT reconstruction after preprocessing only
indir = self.make_inpaths(ctset[0], ctset[1], args)
# correct location of proj folder in case if prepro was done
in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2], False)
indir[2] = os.path.join(os.path.split(indir[2])[0], os.path.split(in_proj_dir)[1])
# format command
cmd = "tofu tomo --absorptivity --fix-nan-and-inf"
cmd += " --darks {} --flats {} --projections {}".format(indir[0], indir[1], indir[2])
if ctset[1] == 4: # must be equivalent to len(indir)>3
cmd += " --flats2 {}".format(indir[3])
cmd += " --output {}".format(out_pattern)
cmd += " --axis {}".format(ax)
cmd += " --offset {}".format(args.main_region_rotate_volume_clock)
cmd += " --number {}".format(nviews)
if args.step > 0.0:
cmd += ' --angle {}'.format(args.step)
cmd = self.check_vcrop(cmd, args.main_region_select_rows,
args.main_region_first_row,
args.main_region_number_rows,
args.main_region_nth_row, WH[0])
cmd = self.check_8bit(cmd, args.main_region_clip_histogram,
args.main_region_bit_depth,
args.main_region_histogram_min,
args.main_region_histogram_max)
cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff)
return cmd
def get_ct_proj_cmd(self, out_pattern, ax, args, nviews, WH):
# CT reconstruction from pre-processed and flat-corrected projections
in_proj_dir, quatsch = fmt_in_out_path(
args.main_config_temp_dir, "obsolete;if-you-need-fix-it", self._fdt_names[2], False
)
cmd = "tofu tomo --projections {}".format(in_proj_dir)
cmd += " --output {}".format(out_pattern)
cmd += " --axis {}".format(ax)
cmd += " --offset {}".format(args.main_region_rotate_volume_clock)
cmd += " --number {}".format(nviews)
if args.step > 0.0:
cmd += ' --angle {}'.format(args.step)
cmd = self.check_vcrop(cmd, args.main_region_select_rows,
args.main_region_first_row,
args.main_region_number_rows,
args.main_region_nth_row, WH[0])
cmd = self.check_8bit(cmd, args.main_region_clip_histogram,
args.main_region_bit_depth,
args.main_region_histogram_min,
args.main_region_histogram_max)
cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff)
return cmd
def get_ct_sin_cmd(self, out_pattern, ax, args, nviews, WH):
sinos_dir = os.path.join(args.main_config_temp_dir, 'sinos-filt')
cmd = 'tofu tomo --sinograms {}'.format(sinos_dir)
cmd += ' --output {}'.format(out_pattern)
cmd += ' --axis {}'.format(ax)
cmd += ' --offset {}'.format(args.main_region_rotate_volume_clock)
if args.main_region_select_rows:
cmd += ' --number {}'.format(int(args.main_region_number_rows / args.main_region_nth_row))
else:
cmd += " --number {}".format(WH[0])
cmd += " --height {}".format(nviews)
if args.step > 0.0:
cmd += ' --angle {}'.format(args.step)
cmd = self.check_8bit(cmd, args.main_region_clip_histogram,
args.main_region_bit_depth,
args.main_region_histogram_min,
args.main_region_histogram_max)
cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff)
return cmd
def get_sinos_ffc_cmd(self, ctset, tmpdir, args, nviews, WH):
indir = self.make_inpaths(ctset[0], ctset[1], args)
in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2], False)
cmd = 'tofu sinos --absorptivity --fix-nan-and-inf'
cmd += ' --darks {} --flats {} '.format(indir[0], indir[1])
if ctset[1] == 4:
cmd += " --flats2 {}".format(indir[3])
cmd += " --projections {}".format(in_proj_dir)
cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif"))
cmd += " --number {}".format(nviews)
cmd = self.check_vcrop(cmd, args.main_region_select_rows,
args.main_region_first_row,
args.main_region_number_rows,
args.main_region_nth_row, WH[0])
if not args.main_filters_ring_removal_ufo_lpf:
# because second RR algorithm does not know how to work with multipage tiffs
cmd += " --output-bytes-per-file 0"
if not args.advanced_advtofu_aux_ffc_dark_scale == "":
cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale)
if not args.advanced_advtofu_aux_ffc_flat_scale == "":
cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale)
return cmd
def get_sinos_noffc_cmd(self, ctsetpath, tmpdir, args, nviews, WH):
in_proj_dir, out_pattern = fmt_in_out_path(
args.main_config_temp_dir, ctsetpath, self._fdt_names[2], False
)
cmd = "tofu sinos"
cmd += " --projections {}".format(in_proj_dir)
cmd += " --output {}".format(os.path.join(tmpdir, "sinos/sin-%04i.tif"))
cmd += " --number {}".format(nviews)
cmd = self.check_vcrop(cmd, args.main_region_select_rows,
args.main_region_first_row,
args.main_region_number_rows,
args.main_region_nth_row,
WH[0])
if not args.main_filters_ring_removal_ufo_lpf:
# because second RR algorithm does not know how to work with multipage tiffs
cmd += " --output-bytes-per-file 0"
return cmd
def get_sinos2proj_cmd(self, args, proj_height):
quatsch, out_pattern = fmt_in_out_path(args.main_config_temp_dir, 'quatsch', self._fdt_names[2], True)
cmd = 'tofu sinos'
cmd += ' --projections {}'.format(os.path.join(args.main_config_temp_dir, 'sinos-filt'))
cmd += ' --output {}'.format(out_pattern)
if not args.main_region_select_rows:
cmd += ' --number {}'.format(proj_height)
else:
cmd += ' --number {}'.format(int(args.main_region_number_rows / args.main_region_nth_row))
return cmd
def get_sinFFC_cmd(self, ctset, args, nviews, n):
indir = self.make_inpaths(ctset[0], ctset[1], args)
in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2])
cmd = 'bmit_sin --fix-nan'
cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir)
if ctset[1] == 4:
cmd += ' --flats2 {}'.format(indir[3])
cmd += ' --output {}'.format(os.path.dirname(out_pattern))
cmd += ' --method {}'.format(args.advanced_ffc_method)
cmd += ' --multiprocessing'
cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps)
cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample)
cmd += ' --downsample {}'.format(args.advanced_ffc_downsample)
return cmd
def get_pr_sinFFC_cmd(self, ctset, args, nviews, n):
indir = self.make_inpaths(ctset[0], ctset[1], args)
in_proj_dir, out_pattern = fmt_in_out_path(
args.main_config_temp_dir, ctset[0], self._fdt_names[2])
cmd = 'bmit_sin --fix-nan'
cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir)
if ctset[1] == 4:
cmd += ' --flats2 {}'.format(indir[3])
cmd += ' --output {}'.format(os.path.dirname(out_pattern))
cmd += ' --method {}'.format(args.advanced_ffc_method)
cmd += ' --multiprocessing'
cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps)
cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample)
cmd += ' --downsample {}'.format(args.advanced_ffc_downsample)
return cmd
def get_pr_tofu_cmd_sinFFC(self, ctset, args, nviews, WH):
# indir will format paths to flats darks and tomo2 correctly even if they were
# pre-processed, however path to the input directory with projections
# cannot be formatted with that command correctly
# indir = self.make_inpaths(ctset[0], ctset[1])
# so we need a separate "universal" command which considers all previous steps
in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2])
# Phase retrieval
cmd = 'tofu preprocess --delta 1e-6'
cmd += ' --energy {} --propagation-distance {}' \
' --pixel-size {} --regularization-rate {:0.2f}' \
.format(args.main_pr_photon_energy, args.main_pr_detector_distance,
args.main_pr_pixel_size, args.main_pr_delta_beta_ratio)
cmd += ' --projections {}'.format(in_proj_dir)
cmd += ' --output {}'.format(out_pattern)
cmd += ' --projection-crop-after filter'
return cmd
def get_pr_tofu_cmd(self, ctset, args, nviews, WH):
# indir will format paths to flats darks and tomo2 correctly even if they were
# pre-processed, however path to the input directory with projections
# cannot be formatted with that command correctly
indir = self.make_inpaths(ctset[0], ctset[1], args)
# so we need a separate "universal" command which considers all previous steps
in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2])
cmd = 'tofu preprocess --fix-nan-and-inf --projection-filter none --delta 1e-6'
cmd += ' --darks {} --flats {} --projections {}'.format(indir[0], indir[1], in_proj_dir)
if ctset[1] == 4:
cmd += ' --flats2 {}'.format(indir[3])
cmd += ' --output {}'.format(out_pattern)
cmd += ' --energy {} --propagation-distance {}' \
' --pixel-size {} --regularization-rate {:0.2f}' \
.format(args.main_pr_photon_energy, args.main_pr_detector_distance,
args.main_pr_pixel_size, args.main_pr_delta_beta_ratio)
if not args.advanced_advtofu_aux_ffc_dark_scale == "":
cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale)
if not args.advanced_advtofu_aux_ffc_flat_scale == "":
cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale)
return cmd
def get_reco_cmd(self, ctset, out_pattern, ax, args, nviews, WH, ffc, PR):
# direct CT reconstruction from input dir to output dir;
# or CT reconstruction after preprocessing only
indir = self.make_inpaths(ctset[0], ctset[1], args)
# correct location of proj folder in case if prepro was done
in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir,
ctset[0], self._fdt_names[2], False)
# Laminography
cmd = ''
if args.advanced_advtofu_extended_settings is True:
cmd += self.check_lamino(cmd, args)
elif args.advanced_advtofu_extended_settings is False:
cmd = "tofu reco"
cmd += ' --overall-angle 180'
##############
cmd += ' --projections {}'.format(in_proj_dir)
cmd += ' --output {}'.format(out_pattern)
if ffc:
cmd += ' --fix-nan-and-inf'
cmd += ' --darks {} --flats {}'.format(indir[0], indir[1])
if ctset[1] == 4: # must be equivalent to len(indir)>3
cmd += ' --flats2 {}'.format(indir[3])
if not PR:
cmd += ' --absorptivity'
if not args.advanced_advtofu_aux_ffc_dark_scale == "":
cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale)
if not args.advanced_advtofu_aux_ffc_flat_scale == "":
cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale)
if PR:
cmd += (
" --disable-projection-crop"
" --delta 1e-6"
" --energy {} --propagation-distance {}"
" --pixel-size {} --regularization-rate {:0.2f}" \
.format(args.main_pr_photon_energy, args.main_pr_detector_distance,
args.main_pr_pixel_size, args.main_pr_delta_beta_ratio)
)
cmd += " --center-position-x {}".format(ax)
# if args.nviews==0:
cmd += " --number {}".format(nviews)
# elif args.nviews>0:
# cmd += ' --number {}'.format(args.nviews)
cmd += ' --volume-angle-z {:0.5f}'.format(args.main_region_rotate_volume_clock)
# rows-slices to be reconstructed
# full ROI
b = int(np.ceil(WH[0] / 2.0))
a = -int(WH[0] / 2.0)
c = 1
if args.main_region_select_rows:
if args.main_filters_ring_removal:
h2 = args.main_region_number_rows / args.main_region_nth_row / 2.0
b = np.ceil(h2)
a = -int(h2)
else:
h2 = int(WH[0] / 2.0)
a = args.main_region_first_row - h2
b = args.main_region_first_row + args.main_region_number_rows - h2
c = args.main_region_nth_row
cmd += ' --region={},{},{}'.format(a, b, c)
# crop of reconstructed slice in the axial plane
b = WH[1] / 2
if args.main_region_crop_slices:
cmd += ' --x-region={},{},{}'.format(args.main_region_crop_x - b,
args.main_region_crop_x + args.main_region_crop_width - b, 1)
cmd += ' --y-region={},{},{}'.format(args.main_region_crop_y - b,
args.main_region_crop_y + args.main_region_crop_height - b, 1)
# cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0])
cmd = self.check_8bit(cmd, args.main_region_clip_histogram,
args.main_region_bit_depth,
args.main_region_histogram_min,
args.main_region_histogram_max)
cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff)
# Optimization
cmd += ' --slice-memory-coeff={}'.format(args.advanced_optimize_slice_mem_coeff)
if args.advanced_optimize_verbose_console:
cmd += ' --verbose'
if not args.advanced_optimize_num_gpus == '':
cmd += ' --gpus {}'.format(args.advanced_optimize_num_gpus)
if not args.advanced_optimize_slices_per_device == '':
cmd += ' --slices-per-device {}'.format(args.advanced_optimize_slices_per_device)
return cmd
def get_reco_cmd_sinFFC(self, ctset, out_pattern, ax, args, nviews, WH, ffc, PR):
    """Build the ``tofu reco`` command line for one CT set.

    Projections are taken from the directory returned by ``fmt_in_out_path``
    for ``args.main_config_temp_dir`` (called with croutdir=False, so no
    directory is created here).

    :param ctset: CT set; ``ctset[0]`` is its directory, ``ctset[1]`` the
        directory-type flag passed to ``make_inpaths``
    :param out_pattern: Output pattern passed to ``--output``
    :param ax: Center of rotation passed to ``--center-position-x``
    :param args: Parameter namespace
    :param nviews: Number of projections passed to ``--number``
    :param WH: Projection size; WH[0] is used for the row range, WH[1] for
        the slice crop offsets
    :param ffc: Not used in this method
    :param PR: If true, phase-retrieval options are added
    :return: The assembled command string
    """
    # direct CT reconstruction from input dir to output dir;
    # or CT reconstruction after preprocessing only
    # NOTE(review): indir is only referenced by the commented-out code below
    indir = self.make_inpaths(ctset[0], ctset[1], args)
    # correct location of proj folder in case if prepro was done
    in_proj_dir, quatsch = fmt_in_out_path(args.main_config_temp_dir,
                                           ctset[0], self._fdt_names[2], False)
    # in_proj_dir, quatsch = fmt_in_out_path(args.tmpdir,args.indir, self._fdt_names[2], False)
    # indir[2]=os.path.join(os.path.split(indir[2])[0], os.path.split(in_proj_dir)[1])
    # format command
    cmd = "tofu reco"
    # Laminography
    if args.advanced_advtofu_extended_settings:
        # NOTE(review): check_8bit/check_bigtif below are used as
        # ``cmd = self.check_x(cmd, ...)``, but here the result is appended with +=.
        # Confirm check_lamino returns only the extra options (not cmd itself);
        # otherwise "tofu reco" appears twice in the command.
        cmd += self.check_lamino(cmd, args)
    else:
        cmd += " --overall-angle 180"
    ##############
    cmd += " --projections {}".format(in_proj_dir)
    cmd += " --output {}".format(out_pattern)
    if PR:
        # Phase-retrieval options
        cmd += ' --disable-projection-crop' \
               ' --delta 1e-6' \
               ' --energy {} --propagation-distance {}' \
               ' --pixel-size {} --regularization-rate {:0.2f}' \
            .format(args.main_pr_photon_energy, args.main_pr_detector_distance,
                    args.main_pr_pixel_size, args.main_pr_delta_beta_ratio)
    cmd += ' --center-position-x {}'.format(ax)
    # if args.nviews==0:
    cmd += " --number {}".format(nviews)
    # elif args.nviews>0:
    # cmd += ' --number {}'.format(args.nviews)
    cmd += " --volume-angle-z {:0.5f}".format(args.main_region_rotate_volume_clock)
    # rows-slices to be reconstructed
    # full ROI: rows -int(H/2) .. ceil(H/2) relative to the detector centre, step 1
    b = int(np.ceil(WH[0] / 2.0))
    a = -int(WH[0] / 2.0)
    c = 1
    if args.main_region_select_rows:
        if args.main_filters_ring_removal:
            # Range centred on number_rows / nth_row rows (presumably the
            # sinograms already hold only the selected rows here - confirm)
            h2 = args.main_region_number_rows / args.main_region_nth_row / 2.0
            b = np.ceil(h2)
            a = -int(h2)
        else:
            # Selected rows expressed relative to the detector centre
            h2 = int(WH[0] / 2.0)
            a = args.main_region_first_row - h2
            b = args.main_region_first_row + args.main_region_number_rows - h2
            c = args.main_region_nth_row
    cmd += ' --region={},{},{}'.format(a, b, c)
    # crop of reconstructed slice in the axial plane; both x and y offsets are
    # taken relative to WH[1] / 2 (slice presumably square - confirm)
    b = WH[1] / 2
    if args.main_region_crop_slices:
        cmd += ' --x-region={},{},{}'.format(args.main_region_crop_x - b,
                                             args.main_region_crop_x + args.main_region_crop_width - b, 1)
        cmd += ' --y-region={},{},{}'.format(args.main_region_crop_y - b,
                                             args.main_region_crop_y + args.main_region_crop_height - b, 1)
    # cmd = self.check_vcrop(cmd, args.main_region_select_rows, args.main_region_first_row, args.main_region_number_rows, args.main_region_nth_row, WH[0])
    cmd = self.check_8bit(cmd, args.main_region_clip_histogram,
                          args.main_region_bit_depth,
                          args.main_region_histogram_min,
                          args.main_region_histogram_max)
    cmd = self.check_bigtif(cmd, args.main_config_save_multipage_tiff)
    # Optimization
    cmd += ' --slice-memory-coeff={}'.format(args.advanced_optimize_slice_mem_coeff)
    if args.advanced_optimize_verbose_console:
        cmd += ' --verbose'
    # Empty string means "not set": the option is omitted
    if not args.advanced_optimize_num_gpus == '':
        cmd += ' --gpus {}'.format(args.advanced_optimize_num_gpus)
    if not args.advanced_optimize_slices_per_device == '':
        cmd += ' --slices-per-device {}'.format(args.advanced_optimize_slices_per_device)
    return cmd
tofu-0.12.0/tofu/ez/ufo_cmd_gen.py 0000664 0000000 0000000 00000025516 14237137211 0016773 0 ustar 00root root 0000000 0000000 #!/bin/python
"""
Created on Apr 6, 2018
@author: gasilos
"""
import glob
import os
from tofu.util import get_filenames, read_image, next_power_of_two
from tofu.ez.util import enquote
def fmt_in_out_path(tmpdir, indir, raw_proj_dir_name, croutdir=True):
    """Suggest the input directory and output pattern for the next projection step.

    Every processing step writes its projections into ``<tmpdir>/proj-stepN``.
    The most recent of these directories is the input of the next step. If
    none exists yet, the raw projections ``<indir>/<raw_proj_dir_name>`` are
    the input.

    :param tmpdir: Temporary directory holding the ``proj-stepN`` directories
    :param indir: Directory that contains the raw projection directory
    :param raw_proj_dir_name: Name of the raw projection directory inside *indir*
    :param croutdir: Create the output directory if it does not exist yet
    :return: ``(in_proj_dir, out_pattern)``. *out_pattern* is
        ``<tmpdir>/proj-step<N+1>/proj-%04i.tif``, where N is the number of
        existing ``proj-step*`` directories.
    """
    step_prefix = "proj-step"

    def _step_key(path):
        # Sort numerically so that proj-step10 comes after proj-step9. Names
        # with a non-numeric suffix sort after all numbered steps.
        suffix = os.path.basename(path)[len(step_prefix):]
        if suffix.isdecimal():
            return (0, int(suffix), "")
        return (1, 0, suffix)

    candidates = glob.glob(os.path.join(tmpdir, step_prefix + "*"))
    proj_dirs = sorted((d for d in candidates if os.path.isdir(d)), key=_step_key)
    n_steps = len(proj_dirs)
    if n_steps == 0:
        # No intermediate results yet: start from the raw projections
        in_proj_dir = os.path.join(indir, raw_proj_dir_name)
    else:
        # There are directories proj-stepX in tmpdir: continue from the latest
        in_proj_dir = proj_dirs[-1]
    # Build the output name from tmpdir and the step count only. Deriving it
    # from the glob result would repeat the tmpdir prefix when tmpdir is relative.
    out_proj_dir = os.path.join(tmpdir, "{}{}".format(step_prefix, n_steps + 1))
    # physically create output directory
    if croutdir and not os.path.exists(out_proj_dir):
        os.makedirs(out_proj_dir)
    # return names of input directory and output pattern
    return in_proj_dir, os.path.join(out_proj_dir, "proj-%04i.tif")
class ufo_cmds(object):
    """
    Generates partially formatted ufo-launch and tofu commands
    Parameters are included in the string; pathnames must be added
    """

    def __init__(self, fol):
        # fol: directory names indexed as 0 darks, 1 flats, 2 projections (tomo), 3 flats2
        self._fdt_names = fol

    def make_inpaths(self, lvl0, flats2, args):
        """
        Creates a list of paths to darks/flats/tomo directories
        :param lvl0: Root of directory containing darks/flats/tomo
        :param flats2: The type of directory: 3 contains flats/darks/tomo 4 contains flats/darks/tomo/flats2
        :param args: Parameter namespace; main_config_common_flats_darks selects shared darks/flats
        :return: List of paths to the directories containing darks/flats/tomo and flats2 (if used)
        """
        if args.main_config_common_flats_darks:
            # Common darks/flats (and optional flats2) across multiple reconstructions;
            # only the projection directory is taken from lvl0
            indir = [args.main_config_darks_path,
                     args.main_config_flats_path,
                     os.path.join(lvl0, self._fdt_names[2])]
            if args.main_config_flats2_checkbox:
                indir.append(args.main_config_flats2_path)
            return indir
        # darks/flats/tomo (and flats2 for a type-4 set) live inside lvl0
        indir = [os.path.join(lvl0, name) for name in self._fdt_names[:3]]
        if flats2 != 3:
            indir.append(os.path.join(lvl0, self._fdt_names[3]))
        return indir

    def check_vcrop(self, cmd, vcrop, y, yheight, ystep):
        """Append ``--y/--height/--y-step`` to *cmd* when *vcrop* is set; return the result."""
        if vcrop:
            cmd += " --y {} --height {} --y-step {}".format(y, yheight, ystep)
        return cmd

    def check_bigtif(self, cmd, swi):
        """Append ``bytes-per-file=0`` to *cmd* unless *swi* is true; return the result."""
        if not swi:
            cmd += " bytes-per-file=0"
        return cmd

    @staticmethod
    def _pad_offset(padded_size, size):
        """Return the integer offset that centres *size* inside *padded_size*."""
        # Offsets are pixel positions for pad/crop; floor division keeps them
        # integral (true division produced values such as "12.5" or "25.0").
        return (padded_size - size) // 2

    def get_pr_ufo_cmd(self, args, nviews, WH):
        """Build the ufo-launch phase-retrieval pipeline for the latest projections.

        :param args: Parameter namespace (phase-retrieval and temp-dir settings)
        :param nviews: Number of projections to read
        :param WH: Projection size as [height, width]
        :return: List of shell commands: the ufo-launch pipeline, followed by
            removal of its input directory unless temporary data are kept
        """
        # "quatsch" is a placeholder raw directory. NOTE(review): presumably a
        # proj-stepN directory always exists at this point - confirm.
        in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir,
                                                   "quatsch", self._fdt_names[2])
        height, width = WH[0], WH[1]
        # Pad to the next power of two of (size + 50) before the FFT
        pad_width = next_power_of_two(width + 50)
        pad_height = next_power_of_two(height + 50)
        pad_x = self._pad_offset(pad_width, width)
        pad_y = self._pad_offset(pad_height, height)
        cmd = 'ufo-launch read path={} height={} number={}'.format(in_proj_dir, height, nviews)
        cmd += ' ! pad x={} width={} y={} height={}'.format(pad_x, pad_width, pad_y, pad_height)
        cmd += ' addressing-mode=clamp_to_edge'
        cmd += ' ! fft dimensions=2 ! retrieve-phase'
        cmd += ' energy={} distance={} pixel-size={} regularization-rate={:0.2f}' \
            .format(args.main_pr_photon_energy, args.main_pr_detector_distance,
                    args.main_pr_pixel_size, args.main_pr_delta_beta_ratio)
        cmd += ' ! ifft dimensions=2 crop-width={} crop-height={}'.format(pad_width, pad_height)
        # Crop back to the original projection size
        cmd += ' ! crop x={} width={} y={} height={}'.format(pad_x, width, pad_y, height)
        cmd += ' ! opencl kernel=\'absorptivity\' ! opencl kernel=\'fix_nan_and_inf\' !'
        cmd += ' write filename={}'.format(enquote(out_pattern))
        cmds = [cmd]
        if not args.main_config_keep_temp:
            cmds.append('rm -rf {}'.format(in_proj_dir))
        return cmds

    def get_filter1d_sinos_cmd(self, tmpdir, RR, nviews):
        """Build a ufo-launch command applying ``filter-stripes1d`` to ``<tmpdir>/sinos``.

        Sinograms (height *nviews*) are padded vertically, filtered with
        strength *RR*, cropped back and written to ``<tmpdir>/sinos-filt``.
        """
        sin_in = os.path.join(tmpdir, 'sinos')
        out_pattern = os.path.join(tmpdir, 'sinos-filt/sin-%04i.tif')
        pad_height = next_power_of_two(nviews + 500)
        pad_y = self._pad_offset(pad_height, nviews)
        cmd = 'ufo-launch read path={}'.format(sin_in)
        cmd += ' ! pad y={} height={}'.format(pad_y, pad_height)
        cmd += ' addressing-mode=clamp_to_edge'
        cmd += ' ! transpose ! fft dimensions=1 ! filter-stripes1d strength={}'.format(RR)
        cmd += ' ! ifft dimensions=1 ! transpose'
        cmd += ' ! crop y={} height={}'.format(pad_y, nviews)
        cmd += ' ! write filename={}'.format(enquote(out_pattern))
        return cmd

    def get_filter2d_sinos_cmd(self, tmpdir, sig_hor, sig_ver, nviews, w):
        """Build a ufo-launch command applying 2D ``filter-stripes`` to ``<tmpdir>/sinos``.

        Sinograms (*w* x *nviews*) are padded in both directions, filtered
        with the given sigmas, cropped back and written to ``<tmpdir>/sinos-filt``.
        """
        sin_in = os.path.join(tmpdir, "sinos")
        out_pattern = os.path.join(tmpdir, "sinos-filt/sin-%04i.tif")
        pad_height = next_power_of_two(nviews + 500)
        pad_y = self._pad_offset(pad_height, nviews)
        pad_width = next_power_of_two(w + 500)
        pad_x = self._pad_offset(pad_width, w)
        cmd = "ufo-launch read path={}".format(sin_in)
        cmd += " ! pad x={} width={} y={} height={}".format(pad_x, pad_width, pad_y, pad_height)
        cmd += " addressing-mode=mirrored_repeat"
        cmd += " ! fft dimensions=2 ! filter-stripes horizontal-sigma={} vertical-sigma={}".format(
            sig_hor, sig_ver
        )
        cmd += " ! ifft dimensions=2 crop-width={} crop-height={}".format(pad_width, pad_height)
        cmd += " ! crop x={} width={} y={} height={}".format(pad_x, w, pad_y, nviews)
        cmd += " ! write filename={}".format(enquote(out_pattern))
        return cmd

    def get_pre_cmd(self, ctset, pre_cmd, tmpdir, dryrun, args):
        """Build one ufo-launch command per input directory that applies *pre_cmd*.

        :param ctset: CT set; ``ctset[0]`` is its directory, ``ctset[1]`` its type (3 or 4)
        :param pre_cmd: ufo-launch filter chain inserted between read and write
        :param tmpdir: Temporary directory; projections go to ``<tmpdir>/proj-step1``
        :param dryrun: Not used
        :param args: Parameter namespace
        :return: List of ufo-launch commands, one per darks/flats/tomo(/flats2) directory
        """
        indir = self.make_inpaths(ctset[0], ctset[1], args)
        outdir = self.make_inpaths(tmpdir, ctset[1], args)
        # Preprocessing, if enabled, is always the first step
        outdir[2] = os.path.join(tmpdir, "proj-step1")
        # Create it now: fmt_in_out_path counts existing proj-step* directories
        if not os.path.exists(outdir[2]):
            os.makedirs(outdir[2])
        # NOTE(review): with main_config_common_flats_darks, make_inpaths ignores
        # tmpdir for darks/flats/flats2, so those outputs are written into the
        # shared input directories themselves - confirm this is intended.
        cmds = []
        for fol, out_fol in zip(indir, outdir):
            in_pattern = os.path.join(fol, "*.tif")
            out_pattern = os.path.join(out_fol, "frame-%04i.tif")
            cmd = "ufo-launch"
            cmd += " read path={} ! ".format(enquote(in_pattern))
            cmd += pre_cmd
            cmd += " ! write filename={}".format(enquote(out_pattern))
            cmds.append(cmd)
        return cmds

    def get_inp_cmd(self, ctset, tmpdir, args, N, nviews, any_flat):
        """Build the commands that remove large spots by inpainting.

        Three stages are generated in order:
        1. a spot mask computed by ``tofu find-large-spots`` from *any_flat*;
        2. flat-field correction (``bmit_sin`` or ``tofu flatcorrect``);
        3. filling of masked pixels with ufo-launch ``horizontal-interpolate``.

        :param ctset: CT set; ``ctset[0]`` is its directory, ``ctset[1]`` its type (3 or 4)
        :param tmpdir: Directory where ``mask.tif`` is written
        :param args: Parameter namespace
        :param N: Height passed to the projection reader
        :param nviews: Number of projections to read
        :param any_flat: Image path(s) passed to ``find-large-spots --images``
        :return: List of shell commands to be run in order
        """
        indir = self.make_inpaths(ctset[0], ctset[1], args)
        cmds = []
        ######### CREATE MASK #########
        mask_file = os.path.join(tmpdir, "mask.tif")
        cmd = 'tofu find-large-spots --images {}'.format(any_flat)
        cmd += ' --spot-threshold {} --gauss-sigma {}'.format(
            args.main_filters_remove_spots_threshold,
            args.main_filters_remove_spots_blur_sigma)
        cmd += ' --output {} --output-bytes-per-file 0'.format(mask_file)
        cmds.append(cmd)
        ######### FLAT-CORRECT #########
        in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2])
        if args.advanced_ffc_sinFFC:
            cmd = 'bmit_sin --fix-nan'
            cmd += ' --darks {} --flats {}'.format(indir[0], indir[1])
            cmd += ' --projections {}'.format(in_proj_dir)
            # bmit_sin is given the output directory, not the file pattern
            cmd += ' --output {}'.format(os.path.dirname(out_pattern))
            cmd += ' --multiprocessing'
            if ctset[1] == 4:
                cmd += ' --flats2 {}'.format(indir[3])
            # Add options for eigen-pco-repetitions etc.
            cmd += ' --eigen-pco-repetitions {}'.format(args.advanced_ffc_eigen_pco_reps)
            cmd += ' --eigen-pco-downsample {}'.format(args.advanced_ffc_eigen_pco_downsample)
            cmd += ' --downsample {}'.format(args.advanced_ffc_downsample)
        else:
            cmd = 'tofu flatcorrect --fix-nan-and-inf'
            cmd += ' --darks {} --flats {}'.format(indir[0], indir[1])
            cmd += ' --projections {}'.format(in_proj_dir)
            cmd += ' --output {}'.format(out_pattern)
            if ctset[1] == 4:
                cmd += ' --flats2 {}'.format(indir[3])
            if not args.main_pr_phase_retrieval:
                cmd += ' --absorptivity'
            # Empty string means "not set"
            if not args.advanced_advtofu_aux_ffc_dark_scale == "":
                cmd += ' --dark-scale {}'.format(args.advanced_advtofu_aux_ffc_dark_scale)
            if not args.advanced_advtofu_aux_ffc_flat_scale == "":
                cmd += ' --flat-scale {}'.format(args.advanced_advtofu_aux_ffc_flat_scale)
        cmds.append(cmd)
        # NOTE(review): presumably the inputs are temporary preprocessed copies
        # when preprocessing is enabled, which is why they may be deleted - confirm.
        if not args.main_config_keep_temp and args.main_config_preprocess:
            cmds.append('rm -rf {}'.format(indir[0]))
            cmds.append('rm -rf {}'.format(indir[1]))
            cmds.append('rm -rf {}'.format(in_proj_dir))
            if len(indir) > 3:
                cmds.append("rm -rf {}".format(indir[3]))
        ######### INPAINT #########
        in_proj_dir, out_pattern = fmt_in_out_path(args.main_config_temp_dir, ctset[0], self._fdt_names[2])
        cmd = "ufo-launch [read path={} height={} number={}".format(in_proj_dir, N, nviews)
        cmd += ", read path={}]".format(mask_file)
        cmd += " ! horizontal-interpolate ! "
        cmd += "write filename={}".format(enquote(out_pattern))
        cmds.append(cmd)
        if not args.main_config_keep_temp:
            cmds.append("rm -rf {}".format(in_proj_dir))
        return cmds

    def get_crop_sli(self, out_pattern, args):
        """Build a ufo-launch command cropping the ``*.tif`` slices next to *out_pattern*.

        Cropped slices are written to *out_pattern*. With histogram clipping
        enabled, ``bits=8 rescale=False`` is added to the writer.
        """
        cmd = 'ufo-launch read path={}/*.tif ! '.format(os.path.dirname(out_pattern))
        cmd += 'crop x={} width={} y={} height={} ! '.format(
            args.main_region_crop_x, args.main_region_crop_width,
            args.main_region_crop_y, args.main_region_crop_height)
        cmd += 'write filename={}'.format(out_pattern)
        if args.main_region_clip_histogram:
            cmd += ' bits=8 rescale=False'
        return cmd
tofu-0.12.0/tofu/ez/util.py 0000664 0000000 0000000 00000017530 14237137211 0015500 0 ustar 00root root 0000000 0000000 """
Created on Apr 20, 2020
@author: gasilos
"""
import os
import tifffile
import yaml
import numpy as np
from tofu.util import get_filenames, get_first_filename, get_image_shape, read_image
import tofu.ez.params as parameters
def get_dims(pth):
    """Return the number of projections in directory *pth* and their size.

    :param pth: Directory containing the projection images
    :return: ``(nviews, [height, width], multipage)``. For multipage input,
        *nviews* is the total number of pages over all files. If the first
        image is neither 2D nor 3D, ``(-6, [-6, -6], False)`` is returned.
    :raises ValueError: if the shape of the first image cannot be determined
    """
    first_proj = get_first_filename(pth)
    try:
        shape = get_image_shape(first_proj)
    except Exception as err:
        raise ValueError(
            "Failed to determine size and number of projections in {}".format(pth)) from err
    if len(shape) == 2:  # single page input
        return len(get_filenames(pth)), [shape[-2], shape[-1]], False
    if len(shape) == 3:  # multipage input: count pages over all files
        nviews = sum(get_image_shape(name)[0] for name in get_filenames(pth))
        return nviews, [shape[-2], shape[-1]], True
    # Unsupported dimensionality. Keep the 3-tuple layout of the other branches
    # so callers unpacking three values do not crash.
    return -6, [-6, -6], False
def bad_vert_ROI(multipage, path2proj, y, height):
    """Return True if rows ``y .. y+height`` (clipped to the image) select no row.

    :param multipage: Projections are multipage TIFFs; the first page of the first file is used
    :param path2proj: Directory containing the projections
    :param y: First row of the vertical region of interest
    :param height: Number of rows of the region of interest
    """
    first_file = get_filenames(path2proj)[0]
    if multipage:
        with tifffile.TiffFile(first_file) as tif:
            proj = tif.pages[0].asarray()
    else:
        proj = read_image(first_file)
    # Only the number of selected rows matters, so no float conversion is done.
    # The former np.float alias no longer exists in NumPy >= 1.24.
    y_region = slice(y, min(y + height, proj.shape[0]), 1)
    return proj[y_region, :].shape[0] == 0
def make_copy_of_flat(flatdir, flat_copy_name, dryrun):
    """Provide a copy of the last flat-field image at *flat_copy_name*.

    For single-page input a ``cp`` shell command is returned (not executed).
    For multipage input the last page of the last file is read. It is written
    directly, unless *dryrun* is set, in which case an ``echo`` command is
    returned instead.

    :return: Shell command string; empty if the copy was already written
    :raises ValueError: if the shape of the first flat cannot be determined
    """
    first_flat_file = get_first_filename(flatdir)
    try:
        shape = get_image_shape(first_flat_file)
    except Exception as err:
        raise ValueError(
            "Failed to determine size and number of flats in {}".format(flatdir)) from err
    if len(shape) == 2:
        last_flat_file = get_filenames(flatdir)[-1]
        return "cp {} {}".format(last_flat_file, flat_copy_name)
    flat = read_image(get_filenames(flatdir)[-1])[-1]
    if dryrun:
        return 'echo Will save a copy of flat into "{}"'.format(flat_copy_name)
    # imwrite replaces the deprecated tifffile.imsave
    tifffile.imwrite(flat_copy_name, flat)
    return ""
def clean_tmp_dirs(tmpdir, fdt_names):
    """Delete entries of *tmpdir* whose first four characters match a temp prefix.

    An entry is removed when ``name[:4]`` equals one of "proj", "sino", "mask",
    "flat", "dark", "radi" or an element of *fdt_names*. Only four-character
    elements of *fdt_names* can ever match. A missing *tmpdir* is ignored, and
    entries that cannot be removed are skipped, as with ``rm -rf``.
    """
    import shutil

    tmp_pattern = ["proj", "sino", "mask", "flat", "dark", "radi"]
    tmp_pattern += fdt_names
    if not os.path.exists(tmpdir):
        return
    for filename in os.listdir(tmpdir):
        if filename[:4] not in tmp_pattern:
            continue
        path = os.path.join(tmpdir, filename)
        # Remove in-process instead of via a shell "rm -rf": unquoted names with
        # spaces or shell metacharacters were split or interpreted by the shell.
        if os.path.isdir(path) and not os.path.islink(path):
            shutil.rmtree(path, ignore_errors=True)
        else:
            try:
                os.remove(path)
            except OSError:
                # Best effort, like rm -rf whose exit status was never checked
                pass
def enquote(string, escape=False):
    """Return *string* surrounded by double quotes.

    With *escape* set, each surrounding quote is preceded by a backslash, so
    the result can be embedded inside another double-quoted string.
    """
    quote = '"'
    if escape:
        quote = '\\' + quote
    return ''.join((quote, string, quote))
def save_params(args, ctsetname, ax, nviews, WH):
    """Create the output directory of a CT set and store its parameters.

    Unless this is a dry run, ``<output_dir>/<ctsetname>`` is created. If
    ``args.main_config_save_params`` is also set, ``parameters.params`` is
    dumped to ``parameters.yaml`` and a human-readable summary is written to
    ``reco.params`` in that directory.

    :param args: Parameter namespace
    :param ctsetname: CT set sub-directory name ('' for the output dir itself)
    :param ax: Center of rotation
    :param nviews: Number of projections
    :param WH: Projection dimensions as [height, width]
    """
    if not args.main_config_dry_run and not os.path.exists(args.main_config_output_dir):
        os.makedirs(args.main_config_output_dir)
    out_dir = os.path.join(args.main_config_output_dir, ctsetname)
    if not args.main_config_dry_run and not os.path.exists(out_dir):
        os.makedirs(out_dir)
    if args.main_config_dry_run or not args.main_config_save_params:
        return
    # Dump the params .yaml file (closed even if dumping fails)
    try:
        with open(os.path.join(out_dir, "parameters.yaml"), "w") as yaml_output:
            yaml.dump(parameters.params, yaml_output)
    except FileNotFoundError:
        print("Something went wrong when exporting the .yaml parameters file")
    # Dump the reco.params output file
    with open(os.path.join(out_dir, 'reco.params'), 'w') as f:
        _write_reco_params(f, args, ctsetname, ax, nviews, WH)


def _write_reco_params(f, args, ctsetname, ax, nviews, WH):
    """Write the human-readable reco.params summary of *args* to the text stream *f*."""
    f.write('*** General ***\n')
    f.write('Input directory {}\n'.format(args.main_config_input_dir))
    f.write('CT set {}\n'.format('.' if ctsetname == '' else ctsetname))
    # Search methods 1 and 2 estimate the axis automatically
    if args.main_cor_axis_search_method in (1, 2):
        f.write('Center of rotation {} (auto estimate)\n'.format(ax))
    else:
        f.write('Center of rotation {} (user defined)\n'.format(ax))
    f.write('Dimensions of projections {} x {} (height x width)\n'.format(WH[0], WH[1]))
    f.write('Number of projections {}\n'.format(nviews))

    f.write('*** Preprocessing ***\n')
    preprocess = args.main_config_preprocess_command if args.main_config_preprocess else 'None'
    f.write(' ' + preprocess + '\n')

    f.write('*** Image filters ***\n')
    if args.main_filters_remove_spots:
        f.write(' Remove large spots enabled\n')
        f.write(' threshold {}\n'.format(args.main_filters_remove_spots_threshold))
        f.write(' sigma {}\n'.format(args.main_filters_remove_spots_blur_sigma))
    else:
        f.write(' Remove large spots disabled\n')
    if args.main_pr_phase_retrieval:
        f.write(' Phase retrieval enabled\n')
        f.write(' energy {} keV\n'.format(args.main_pr_photon_energy))
        f.write(' pixel size {:0.1f} um\n'.format(args.main_pr_pixel_size * 1e6))
        f.write(' sample-detector distance {} m\n'.format(args.main_pr_detector_distance))
        # The stored ratio is a power of ten; report the actual delta/beta value
        f.write(' delta/beta ratio {:0.0f}\n'.format(10 ** args.main_pr_delta_beta_ratio))
    else:
        f.write(' Phase retrieval disabled\n')

    f.write('*** Ring removal ***\n')
    if args.main_filters_ring_removal:
        if args.main_filters_ring_removal_ufo_lpf:
            dims = '1D' if args.main_filters_ring_removal_ufo_lpf_1d_or_2d else '2D'
            f.write(' RR with ufo {} stripes filter\n'.format(dims))
            # Terminate each sigma line; without '\n' they ran into the next header
            f.write(' sigma horizontal {}\n'.format(
                args.main_filters_ring_removal_ufo_lpf_sigma_horizontal))
            f.write(' sigma vertical {}\n'.format(
                args.main_filters_ring_removal_ufo_lpf_sigma_vertical))
        else:
            if args.main_filters_ring_removal_sarepy_wide:
                f.write(' RR with ufo sarepy remove wide filter, '
                        + 'window {}, SNR {}\n'.format(
                            args.main_filters_ring_removal_sarepy_window,
                            args.main_filters_ring_removal_sarepy_SNR))
            f.write(' RR with ufo sarepy sorting filter, window {}\n'.format(
                args.main_filters_ring_removal_sarepy_window_size))
    else:
        f.write('RR disabled\n')

    f.write('*** Region of interest ***\n')
    if args.main_region_select_rows:
        f.write('Vertical ROI defined\n')
        f.write(' first row {}\n'.format(args.main_region_first_row))
        f.write(' height {}\n'.format(args.main_region_number_rows))
        f.write(' reconstruct every {}th row\n'.format(args.main_region_nth_row))
    else:
        f.write('Vertical ROI: all rows\n')
    if args.main_region_crop_slices:
        f.write('ROI in slice plane defined\n')
        f.write(' x {}\n'.format(args.main_region_crop_x))
        f.write(' width {}\n'.format(args.main_region_crop_width))
        f.write(' y {}\n'.format(args.main_region_crop_y))
        f.write(' height {}\n'.format(args.main_region_crop_height))
    else:
        f.write('ROI in slice plane not defined\n')

    f.write('*** Reconstructed values ***\n')
    if args.main_region_clip_histogram:
        f.write(' {} bit\n'.format(args.main_region_bit_depth))
        f.write(' Min value in 32-bit histogram {}\n'.format(args.main_region_histogram_min))
        f.write(' Max value in 32-bit histogram {}\n'.format(args.main_region_histogram_max))
    else:
        f.write(' 32bit, histogram untouched\n')

    f.write('*** Optional reco parameters ***\n')
    if args.main_region_rotate_volume_clock > 0:
        f.write(' Rotate volume by: {:0.3f} deg\n'.format(args.main_region_rotate_volume_clock))
tofu-0.12.0/tofu/ez/yaml_in_out.py 0000664 0000000 0000000 00000000772 14237137211 0017042 0 ustar 00root root 0000000 0000000 import yaml
import logging
LOG = logging.getLogger(__name__)
class Yaml_IO:
    """Read and write YAML parameter files."""

    def read_yaml(self, filePath):
        """Load, log at debug level and return the contents of the YAML file *filePath*."""
        with open(filePath) as f:
            data = yaml.load(f, Loader=yaml.FullLoader)
        LOG.debug(data)
        return data

    def write_yaml(self, filePath, params):
        """Dump *params* as YAML into *filePath*.

        If the file cannot be opened because the path does not exist (e.g. an
        empty file name), this is only logged at debug level.
        """
        try:
            file = open(filePath, "w")
        except FileNotFoundError:
            LOG.debug("No filename given")
            return
        # Close the file even if yaml.dump raises
        with file:
            yaml.dump(params, file)
tofu-0.12.0/tofu/find_large_spots.py 0000664 0000000 0000000 00000004331 14237137211 0017422 0 ustar 00root root 0000000 0000000 import logging
from gi.repository import Ufo
from tofu.util import set_node_props, determine_shape, setup_read_task, setup_padding
from tofu.tasks import get_task, get_writer
LOG = logging.getLogger(__name__)
def find_large_spots(args):
    """Build and run the UFO graph that detects large spots in ``args.images``.

    Without ``args.gauss_sigma`` the images go straight into the
    ``find-large-spots`` task. With it, the images are routed to input 0 of
    the ``diff`` OpenCL kernel, while a second reader feeds a
    pad -> blur -> crop chain into input 1. The kernel output then goes to
    ``find-large-spots``. If ``args.blurred_output`` is also set, that kernel
    output is additionally written to ``args.blurred_output``. The result of
    ``find-large-spots`` is written by the writer configured from *args*.
    """
    graph = Ufo.TaskGraph()
    scheduler = Ufo.FixedScheduler()
    reader = get_task('read')
    writer = get_writer(args)

    save_blurred = args.gauss_sigma and args.blurred_output
    if save_blurred:
        broadcast = Ufo.CopyTask()
        blurred_writer = get_task('write')
        # Only touch properties that this version of the write task provides
        for prop_name, prop_value in (('bytes_per_file', 0), ('tiff_bigtiff', False)):
            if hasattr(blurred_writer.props, prop_name):
                setattr(blurred_writer.props, prop_name, prop_value)
        blurred_writer.props.filename = args.blurred_output

    find = get_task('find-large-spots')
    set_node_props(find, args)
    # Assigned after set_node_props so that this padding mode is the one used
    find.props.addressing_mode = args.find_large_spots_padding_mode
    set_node_props(reader, args)
    setup_read_task(reader, args.images, args)

    if not args.gauss_sigma:
        graph.connect_nodes(reader, find)
    else:
        blur_reader = get_task('read')
        set_node_props(blur_reader, args)
        setup_read_task(blur_reader, args.images, args)
        pad = get_task('pad')
        crop = get_task('crop')
        diff = get_task('opencl', kernel='diff', filename='opencl.cl')
        width, height = determine_shape(args, path=args.images)
        # 10 * sigma serves both as blur size and as padding for setup_padding
        gauss_size = int(10 * args.gauss_sigma)
        setup_padding(pad, width, height, args.find_large_spots_padding_mode,
                      crop=crop, pad_width=gauss_size, pad_height=gauss_size)
        LOG.debug("Gauss size: %d", gauss_size)
        blur = get_task('blur', sigma=args.gauss_sigma, size=gauss_size)

        # Raw images -> diff input 0; padded, blurred, cropped images -> diff input 1
        graph.connect_nodes_full(reader, diff, 0)
        for upstream, downstream in ((blur_reader, pad), (pad, blur), (blur, crop)):
            graph.connect_nodes(upstream, downstream)
        graph.connect_nodes_full(crop, diff, 1)

        if save_blurred:
            graph.connect_nodes(diff, broadcast)
            graph.connect_nodes(broadcast, blurred_writer)
            spots_input = broadcast
        else:
            spots_input = diff
        graph.connect_nodes(spots_input, find)

    graph.connect_nodes(find, writer)
    scheduler.run(graph)
tofu-0.12.0/tofu/flow/ 0000775 0000000 0000000 00000000000 14237137211 0014474 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/flow/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0016573 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/flow/composites/ 0000775 0000000 0000000 00000000000 14237137211 0016661 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/flow/composites/ffc-links.cm 0000664 0000000 0000000 00000022102 14237137211 0021053 0 ustar 00root root 0000000 0000000 {
"name": "CFlatFieldCorrect",
"caption": "CFlatFieldCorrect",
"models": {
"Flat Field Correct": {
"model": {
"caption": "Flat Field Correct",
"properties": {
"fix-nan-and-inf": [
true,
true
],
"absorption-correct": [
true,
true
],
"sinogram-input": [
false,
false
],
"dark-scale": [
1.0,
false
],
"flat-scale": [
1.0,
false
]
}
},
"visible": true,
"position": {
"x": 1253.0,
"y": 490.0
},
"name": "flat_field_correct"
},
"Read 2": {
"model": {
"caption": "Read 2",
"properties": {
"path": [
".",
true
],
"start": [
0,
false
],
"number": [
4294967295,
true
],
"step": [
1,
false
],
"y": [
0,
false
],
"height": [
0,
false
],
"y-step": [
1,
false
],
"convert": [
true,
false
],
"raw-width": [
0,
false
],
"raw-height": [
0,
false
],
"raw-bitdepth": [
0,
false
],
"raw-pre-offset": [
0,
false
],
"raw-post-offset": [
0,
false
],
"type": [
"unspecified",
false
],
"retries": [
0,
false
],
"retry-timeout": [
1,
false
]
}
},
"visible": true,
"position": {
"x": 417.0,
"y": 504.0
},
"name": "read"
},
"Average": {
"model": {
"caption": "Average",
"properties": {
"number": [
4294967295,
true
]
}
},
"visible": true,
"position": {
"x": 822.0,
"y": 508.0
},
"name": "average"
},
"Read 3": {
"model": {
"caption": "Read 3",
"properties": {
"path": [
".",
true
],
"start": [
0,
false
],
"number": [
4294967295,
true
],
"step": [
1,
false
],
"y": [
0,
false
],
"height": [
0,
false
],
"y-step": [
1,
false
],
"convert": [
true,
false
],
"raw-width": [
0,
false
],
"raw-height": [
0,
false
],
"raw-bitdepth": [
0,
false
],
"raw-pre-offset": [
0,
false
],
"raw-post-offset": [
0,
false
],
"type": [
"unspecified",
false
],
"retries": [
0,
false
],
"retry-timeout": [
1,
false
]
}
},
"visible": true,
"position": {
"x": 413.0,
"y": 735.0
},
"name": "read"
},
"Average 2": {
"model": {
"caption": "Average 2",
"properties": {
"number": [
4294967295,
true
]
}
},
"visible": true,
"position": {
"x": 822.0,
"y": 741.0
},
"name": "average"
},
"Read": {
"model": {
"caption": "Read",
"properties": {
"path": [
".",
true
],
"start": [
0,
false
],
"number": [
23212,
true
],
"step": [
1,
false
],
"y": [
0,
false
],
"height": [
0,
false
],
"y-step": [
1,
false
],
"convert": [
true,
false
],
"raw-width": [
0,
false
],
"raw-height": [
0,
false
],
"raw-bitdepth": [
0,
false
],
"raw-pre-offset": [
0,
false
],
"raw-post-offset": [
0,
false
],
"type": [
"unspecified",
false
],
"retries": [
0,
false
],
"retry-timeout": [
1,
false
]
}
},
"visible": true,
"position": {
"x": 418.0,
"y": 245.0
},
"name": "read"
}
},
"connections": [
[
"Read",
0,
"Flat Field Correct",
0
],
[
"Average",
0,
"Flat Field Correct",
1
],
[
"Average 2",
0,
"Flat Field Correct",
2
],
[
"Read 2",
0,
"Average",
0
],
[
"Read 3",
0,
"Average 2",
0
]
],
"links": [
[
[
"Read 2",
"number"
],
[
"Average",
"number"
]
],
[
[
"Read 3",
"number"
],
[
"Average 2",
"number"
]
]
]
}
tofu-0.12.0/tofu/flow/composites/pr.cm 0000664 0000000 0000000 00000011516 14237137211 0017627 0 ustar 00root root 0000000 0000000 {
"name": "CPhaseRetrieve",
"caption": "CPhaseRetrieve",
"models": {
"Fft": {
"model": {
"caption": "Fft",
"properties": {
"auto-zeropadding": [
true,
true
],
"dimensions": [
2,
true
],
"size-x": [
1,
true
],
"size-y": [
1,
true
],
"size-z": [
1,
true
]
}
},
"visible": true,
"position": {
"x": 112.0,
"y": 245.0
},
"name": "fft"
},
"Ifft": {
"model": {
"caption": "Ifft",
"properties": {
"dimensions": [
2,
true
],
"crop-width": [
-1,
true
],
"crop-height": [
-1,
true
]
}
},
"visible": true,
"position": {
"x": 772.0,
"y": 250.0
},
"name": "ifft"
},
"Retrieve Phase": {
"model": {
"caption": "Retrieve Phase",
"num-inputs": 1,
"properties": {
"method": [
"tie",
true
],
"energy": [
20.0,
true
],
"distance": [
0.0,
true
],
"distance-x": [
0.0,
true
],
"distance-y": [
0.0,
true
],
"pixel-size": [
7.500000265281415e-07,
true
],
"regularization-rate": [
2.5,
true
],
"thresholding-rate": [
0.10000000149011612,
true
],
"frequency-cutoff": [
3.4028234663852886e+38,
true
],
"output-filter": [
false,
true
]
}
},
"visible": true,
"position": {
"x": 544.0,
"y": 515.0
},
"name": "retrieve_phase"
},
"Pad": {
"model": {
"caption": "Pad",
"properties": {
"width": [
0,
true
],
"height": [
0,
true
],
"x": [
0,
true
],
"y": [
0,
true
],
"addressing-mode": [
"clamp",
true
]
}
},
"visible": true,
"position": {
"x": 0.0,
"y": 570.0
},
"name": "pad"
}
},
"connections": [
[
"Pad",
0,
"Fft",
0
],
[
"Fft",
0,
"Retrieve Phase",
0
],
[
"Retrieve Phase",
0,
"Ifft",
0
]
],
"links": [
[
[
"Fft",
"dimensions"
],
[
"Ifft",
"dimensions"
]
],
[
[
"Fft",
"size-x"
],
[
"Pad",
"width"
]
],
[
[
"Fft",
"size-y"
],
[
"Pad",
"height"
]
]
]
}
tofu-0.12.0/tofu/flow/config.json 0000664 0000000 0000000 00000006771 14237137211 0016647 0 ustar 00root root 0000000 0000000 {
"models": {
"average": {
"hidden-properties": [
"number"
]
},
"flat-field-correct": {
"port-captions": {
"input": {
"0": "radios",
"1": "darks",
"2": "flats"
},
"output": {
"0": ""
}
},
"hidden-properties": [
"sinogram-input",
"dark-scale",
"flat-scale"
]
},
"general-backproject": {
"hidden-properties": [
"z",
"burst",
"source-position-x",
"source-position-y",
"source-position-z",
"detector-position-x",
"detector-position-y",
"detector-position-z",
"detector-angle-x",
"detector-angle-y",
"detector-angle-z",
"axis-angle-x",
"axis-angle-y",
"axis-angle-z",
"volume-angle-x",
"volume-angle-y",
"volume-angle-z",
"compute-type",
"result-type",
"store-type",
"addressing-mode",
"gray-map-min",
"gray-map-max"
],
"range-properties": {
"region": [3, true],
"x-region": [3, true],
"y-region": [3, true],
"center-position-x": [null, true],
"center-position-z": [null, true],
"source-position-x": [null, true],
"source-position-y": [null, true],
"source-position-z": [null, true],
"detector-position-x": [null, true],
"detector-position-y": [null, true],
"detector-position-z": [null, true],
"detector-angle-x": [null, true],
"detector-angle-y": [null, true],
"detector-angle-z": [null, true],
"axis-angle-x": [null, true],
"axis-angle-y": [null, true],
"axis-angle-z": [null, true],
"volume-angle-x": [null, true],
"volume-angle-y": [null, true],
"volume-angle-z": [null, true]
}
},
"horizontal-interpolate": {
"port-captions": {
"input": {
"0": "image",
"1": "mask"
},
"output": {
"0": ""
}
}
},
"read": {
"hidden-properties": [
"start",
"step",
"y",
"height",
"y-step",
"convert",
"raw-width",
"raw-height",
"raw-bitdepth",
"raw-pre-offset",
"raw-post-offset",
"type",
"retries",
"retry-timeout"
]
},
"write": {
"hidden-properties": [
"counter-start",
"counter-step",
"bytes-per-file",
"append",
"bits",
"minimum",
"maximum",
"rescale",
"jpeg-quality",
"tiff-bigtiff"
]
}
}
}
tofu-0.12.0/tofu/flow/execution.py 0000664 0000000 0000000 00000020256 14237137211 0017056 0 ustar 00root root 0000000 0000000 import gi
import logging
import networkx as nx
gi.require_version('Ufo', '0.0')
from gi.repository import Ufo
from PyQt5.QtCore import QObject, pyqtSignal
from qtpynodeeditor import PortType
from threading import Thread
from tofu.flow.models import ARRAY_DATA_TYPE, UFO_DATA_TYPE, UfoTaskModel
from tofu.flow.util import FlowError
LOG = logging.getLogger(__name__)
class UfoExecutor(QObject):
"""Class holding GPU resources and organizing UFO graph execution."""
number_of_inputs_changed = pyqtSignal(int) # Number of inputs has been determined
processed_signal = pyqtSignal(int) # Image has been processed
execution_started = pyqtSignal() # Graph execution started
execution_finished = pyqtSignal() # Graph execution finished
exception_occured = pyqtSignal(str)
def __init__(self):
super().__init__(parent=None)
self._resources = Ufo.Resources()
self._reset()
# If True only log the exception and emit the signal but don't re-raise it in the executing
# thread
self.swallow_run_exceptions = False
def _reset(self):
self._aborted = False
self._schedulers = []
self.num_generated = 0
def abort(self):
LOG.debug('Execution aborted')
try:
self._aborted = True
for scheduler in self._schedulers:
scheduler.abort()
finally:
self.execution_finished.emit()
def on_processed(self, ufo_task):
self.processed_signal.emit(self.num_generated)
self.num_generated += 1
def setup_ufo_graph(self, graph, gpu=None, region=None, signalling_model=None):
ufo_graph = Ufo.TaskGraph()
ufo_tasks = {}
for source, dest, ports in graph.edges.data():
if hasattr(source, 'create_ufo_task') and hasattr(dest, 'create_ufo_task'):
if dest not in ufo_tasks:
ufo_tasks[dest] = dest.create_ufo_task(region=region)
if source not in ufo_tasks:
ufo_tasks[source] = source.create_ufo_task(region=region)
ufo_graph.connect_nodes_full(ufo_tasks[source],
ufo_tasks[dest],
ports[PortType.input])
LOG.debug(f'{source.name}->{dest.name}@{ports[PortType.input]}')
if source == signalling_model:
ufo_tasks[source].connect('generated', self.on_processed)
if gpu is not None:
for task in ufo_tasks.values():
if task.uses_gpu():
task.set_proc_node(gpu)
return ufo_graph
def _run_ufo_graph(self, ufo_graph, use_fixed_scheduler):
LOG.debug(f'Executing graph, fixed scheduler: {use_fixed_scheduler}')
try:
scheduler = Ufo.FixedScheduler() if use_fixed_scheduler else Ufo.Scheduler()
self._schedulers.append(scheduler)
scheduler.set_resources(self._resources)
scheduler.run(ufo_graph)
LOG.info(f'Execution time: {scheduler.props.time} s')
except Exception as e:
# Do not continue execution of other batches
self._aborted = True
LOG.error(e, exc_info=True)
self.exception_occured.emit(str(e))
if not self.swallow_run_exceptions:
raise e
def check_graph(self, graph):
"""
Check that *graph* starts with an UfoTaskModel and ends with either that or an UfoModel
but no UfoTaskModel successor exists (there can be only one UFO path in the graph).
"""
roots = [n for n in graph.nodes if graph.in_degree(n) == 0]
leaves = [n for n in graph.nodes if graph.out_degree(n) == 0]
for root in roots:
for leave in leaves:
for path in nx.simple_paths.all_simple_paths(graph, root, leave):
if not isinstance(path[0], UfoTaskModel):
raise FlowError('Flow must start with an UFO node')
ufo_ended = False
for (i, succ) in enumerate(path[1:]):
model = path[i]
edge_data = graph.get_edge_data(model, succ)
if len(edge_data) > 1:
# There cannot be multiple edges between nodes
raise FlowError('Multiple edges not allowed but detected '
'between {model} and {succ}')
out_index = edge_data[0]['output']
# We don't need to check if input data type is ARRAY_DATA_TYPE because
# UFO_DATA_TYPE cannot be connected to ARRAY_DATA_TYPE in the scene
if ufo_ended:
# From now on only non-UFO tasks are allowed
if model.data_type['output'][out_index] != ARRAY_DATA_TYPE:
raise FlowError('After a non-UFO node cannot come another UFO node')
elif model.data_type['output'][out_index] != UFO_DATA_TYPE:
# Output is non-UFO, UFO ends here
ufo_ended = True
def run(self, graph):
    """
    Execute *graph* in a background daemon thread and return without waiting for it.

    The graph is validated by check_graph first. If the graph contains a GPU splitting model (at
    most one is allowed), its split_gpu_work determines the batches; every batch is a list of
    (gpu_index, region) pairs which are run in parallel (one thread each) and the batches are run
    one after another. Otherwise a single run without GPU assignment and region is made.
    execution_started is emitted when the background thread starts and execution_finished when it
    ends, also on errors.

    :raises FlowError: if the graph is invalid or contains more than one GPU splitting model
    """
    self._reset()
    self.check_graph(graph)
    gpus = self._resources.get_gpu_nodes()
    # The root model with the largest 'number' defines the number of inputs which is announced
    # via number_of_inputs_changed (multiplied by the number of batches)
    num_inputs = -1
    signalling_model = None
    for model in graph.nodes:
        if graph.in_degree(model) == 0:
            if 'number' in model:
                current = model['number']
                if current > num_inputs:
                    num_inputs = current
                    signalling_model = model
    # Default: one batch consisting of one run without GPU assignment and without region
    batches = [[(None, None)]]
    gpu_splitting_model = None
    gpu_splitting_models = get_gpu_splitting_models(graph)
    if len(gpu_splitting_models) > 1:
        # There cannot be multiple splitting models
        raise FlowError('Only one gpu splitting model is allowed')
    elif gpu_splitting_models:
        gpu_splitting_model = gpu_splitting_models[0]
        # Each batch is a list of (gpu_index, region) pairs executed in parallel
        batches = gpu_splitting_model.split_gpu_work(self._resources.get_gpu_nodes())
    for model in graph.nodes:
        # Reset internal model state
        if hasattr(model, 'reset_batches'):
            model.reset_batches()
    LOG.debug(f'{len(batches)} batches: {batches}')
    if signalling_model:
        self.number_of_inputs_changed.emit(len(batches) * num_inputs)
        LOG.debug(f'Number of inputs: {len(batches) * num_inputs}, defined '
                  f'by {signalling_model}')

    def execute_batches():
        # Runs in the background thread: batches sequentially, items of one batch in parallel
        self.execution_started.emit()
        try:
            for (i, parallel_batch) in enumerate(batches):
                LOG.info(f'starting batch {i}: {parallel_batch}')
                threads = []
                for gpu_index, region in parallel_batch:
                    if self._aborted:
                        break
                    gpu = None if gpu_index is None else gpus[gpu_index]
                    ufo_graph = self.setup_ufo_graph(graph, gpu=gpu, region=region,
                                                     signalling_model=signalling_model)
                    # The fixed scheduler is used when the work is split among GPUs
                    t = Thread(target=self._run_ufo_graph,
                               args=(ufo_graph,
                                     len(gpu_splitting_models) > 0))
                    t.daemon = True
                    threads.append(t)
                    t.start()
                for t in threads:
                    t.join()
                # _aborted is set when a run fails (see _run_ufo_graph): stop starting batches
                if self._aborted:
                    break
        except Exception as e:
            LOG.error(e, exc_info=True)
            self.exception_occured.emit(str(e))
            raise e
        finally:
            self.execution_finished.emit()

    gt = Thread(target=execute_batches)
    gt.daemon = True
    gt.start()
def get_gpu_splitting_models(graph):
    """Return a list of the UfoTaskModel nodes of *graph* whose can_split_gpu_work is true."""
    return [model for model in graph.nodes
            if isinstance(model, UfoTaskModel) and model.can_split_gpu_work]
tofu-0.12.0/tofu/flow/filedirdialog.py 0000664 0000000 0000000 00000001313 14237137211 0017642 0 ustar 00root root 0000000 0000000 import os
from PyQt5.QtWidgets import QFileDialog
class FileDirDialog(QFileDialog):
    """
    A workaround for being able to select both files and directories.

    The file mode is switched between Directory and ExistingFile whenever the highlighted entry
    changes, depending on whether that entry is a directory.

    Source:
    https://stackoverflow.com/questions/27520304/qfiledialog-that-accepts-a-single-file-or-a-single-directory
    """

    def __init__(self, parent=None):
        super().__init__(parent=parent)
        # Use the Qt (non-native) dialog
        self.setOption(QFileDialog.DontUseNativeDialog)
        self.setFileMode(QFileDialog.Directory)
        self.currentChanged.connect(self._selected)

    def _selected(self, name):
        """Adapt the file mode to the currently highlighted path *name*."""
        mode = QFileDialog.Directory if os.path.isdir(name) else QFileDialog.ExistingFile
        self.setFileMode(mode)
tofu-0.12.0/tofu/flow/main.py 0000664 0000000 0000000 00000054341 14237137211 0016001 0 ustar 00root root 0000000 0000000 import json
import logging
import os
import pathlib
import sys
from PyQt5.QtCore import Qt, QObject, QPoint, pyqtSignal
from PyQt5.QtWidgets import (QApplication, QFileDialog, QWidget, QVBoxLayout, QMenuBar,
QMessageBox, QProgressBar, QMainWindow, QStyle)
from qtpynodeeditor import DataModelRegistry, FlowView
from xdg import xdg_data_home
from tofu.flow.execution import UfoExecutor
from tofu.flow.models import (BaseCompositeModel, get_composite_model_classes_from_json,
get_composite_model_classes, get_ufo_model_classes, ImageViewerModel,
UfoGeneralBackprojectModel, UfoMemoryOutModel, UfoOpenCLModel,
UfoReadModel, UfoRetrievePhaseModel, UfoWriteModel)
from tofu.flow.scene import UfoScene
from tofu.flow.propertylinkswidget import PropertyLinks
from tofu.flow.runslider import RunSlider
from tofu.flow.util import FlowError
# Module-level logger
LOG = logging.getLogger(__name__)
class ApplicationWindow(QMainWindow):
    """
    Main window of tofu flow. Shows the flow view of *ufo_scene* together with the menus and a
    progress bar, manages the auxiliary windows (property links, run slider, Python console) and
    runs the flow by an UfoExecutor.
    """

    def __init__(self, ufo_scene):
        super().__init__()
        self.ufo_scene = ufo_scene
        self.property_links_widget = PropertyLinks(ufo_scene.node_model,
                                                   ufo_scene.property_links_model,
                                                   parent=self)
        self.run_slider = RunSlider(parent=self)
        self.executor = UfoExecutor()
        self.console = None
        # (model, property name) to which the run slider is currently connected
        self.run_slider_key = (None, None)
        # Directories last used in the scene and composite file dialogs
        self.last_dirs = {'scene': None, 'composite': None}
        # Set while composites are being created/expanded so that on_node_deleted keeps the run
        # slider connection of nodes which are only moved into or out of a composite
        self._creating_composite = False
        self._expanding_composite = False

        central_widget = QWidget()
        self.setCentralWidget(central_widget)
        main_layout = QVBoxLayout(central_widget)
        self.flow_view = FlowView(self.ufo_scene)
        self.progress_bar = QProgressBar()
        self.progress_bar.setMinimum(0)

        menu_bar = QMenuBar()
        # Flow menu
        flow_menu = menu_bar.addMenu('Flow')
        new_action = flow_menu.addAction("New")
        new_action.setShortcut('Ctrl+N')
        new_action.triggered.connect(self.on_new)
        save_action = flow_menu.addAction("Save")
        save_action.setShortcut('Ctrl+S')
        save_action.triggered.connect(self.on_save)
        load_action = flow_menu.addAction("Open")
        load_action.setShortcut('Ctrl+O')
        load_action.triggered.connect(self.on_open)
        self.run_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaPlay),
                                              'Run')
        self.run_action.setShortcut('Ctrl+R')
        self.run_action.triggered.connect(self.on_run)
        abort_action = flow_menu.addAction(self.style().standardIcon(QStyle.SP_MediaStop), 'Abort')
        abort_action.setShortcut('Ctrl+Shift+X')
        abort_action.triggered.connect(self.executor.abort)
        exit_action = flow_menu.addAction('Exit')
        exit_action.setShortcut('Ctrl+Q')
        exit_action.triggered.connect(self.close)

        # Nodes submenu
        selection_menu = menu_bar.addMenu('Nodes')
        selection_menu.setToolTipsVisible(True)
        selection_menu.aboutToShow.connect(self.on_selection_menu_about_to_show)
        self.skip_action = selection_menu.addAction('Skip Toggle')
        self.skip_action.setShortcut('S')
        self.skip_action.triggered.connect(self.ufo_scene.skip_nodes)
        auto_fill_action = selection_menu.addAction('Auto fill')
        auto_fill_action.triggered.connect(self.ufo_scene.auto_fill)
        copy_action = selection_menu.addAction("Duplicate")
        copy_action.setShortcut('Ctrl+Shift+D')
        copy_action.triggered.connect(self.ufo_scene.copy_nodes)
        # Composite
        create_composite_action = selection_menu.addAction("Create Composite")
        create_composite_action.setShortcut('Ctrl+Shift+C')
        create_composite_action.triggered.connect(self.on_create_composite)
        import_composites_action = selection_menu.addAction("Import Composites")
        import_composites_action.setToolTip('Import one or more composite nodes '
                                            'from a file or files')
        import_composites_action.setShortcut('Ctrl+I')
        import_composites_action.triggered.connect(self.on_import_composites)
        self.export_composite_action = selection_menu.addAction("Export Composite")
        self.export_composite_action.triggered.connect(self.on_export_composite)
        self.edit_composite_action = selection_menu.addAction("Edit Composite")
        self.edit_composite_action.triggered.connect(self.on_edit_composite)
        self.expand_composite_action = selection_menu.addAction("Expand Composite")
        self.expand_composite_action.setShortcut('Ctrl+Shift+E')
        self.expand_composite_action.triggered.connect(self.on_expand_composite)

        # View menu
        view_menu = menu_bar.addMenu('View')
        reset_view_action = view_menu.addAction("Reset Zoom")
        reset_view_action.setShortcut('Ctrl+0')
        reset_view_action.triggered.connect(self.on_reset_view)
        property_links_action = view_menu.addAction("Link Properties")
        property_links_action.setShortcut('Ctrl+L')
        property_links_action.triggered.connect(self.on_property_links_action)
        console_action = view_menu.addAction("Open Python Console")
        console_action.setShortcut('Ctrl+Shift+P')
        console_action.triggered.connect(self.on_console_action)
        run_slider_action = view_menu.addAction("Run Slider")
        run_slider_action.setShortcut('Ctrl+Shift+S')
        run_slider_action.triggered.connect(self.on_run_slider_action)
        # When checked, focusing another input field does not reconnect the run slider
        self.fix_run_slider = view_menu.addAction("Fix Run Slider")
        self.fix_run_slider.setCheckable(True)
        self.fix_run_slider.setShortcut('Ctrl+Alt+Shift+S')

        main_layout.addWidget(menu_bar)
        main_layout.addWidget(self.flow_view)
        main_layout.addWidget(self.progress_bar)
        main_layout.setContentsMargins(0, 0, 0, 0)
        main_layout.setSpacing(0)
        self.resize(1280, 1000)

        # Signals
        self.executor.exception_occured.connect(self.on_exception_occured)
        self.executor.execution_finished.connect(self.on_execution_finished)
        self.executor.number_of_inputs_changed.connect(self.on_number_of_inputs_changed)
        self.executor.processed_signal.connect(self.on_processed)
        self.ufo_scene.node_deleted.connect(self.on_node_deleted)
        self.ufo_scene.nodes_duplicated.connect(self.on_nodes_duplicated)
        self.ufo_scene.item_focus_in.connect(self.on_item_focus_in)
        self.run_slider.value_changed.connect(self.on_run_slider_value_changed)
        self.setWindowTitle('tofu flow')

    def on_save(self):
        """Ask for a file name and save the scene there, appending ".flow" if missing."""
        if self.last_dirs['scene']:
            path = self.last_dirs['scene']
        else:
            path = os.path.join(xdg_data_home(), 'tofu', 'flows')
            if not os.path.exists(path):
                os.makedirs(path)
        file_name, _ = QFileDialog.getSaveFileName(self,
                                                   "Select File Name",
                                                   str(path),
                                                   "Flow Scene Files (*.flow)")
        if file_name:
            # Without the extension the file would be hidden by the open dialog's filter
            if not file_name.endswith('.flow'):
                file_name += '.flow'
            self.last_dirs['scene'] = os.path.dirname(file_name)
            self.ufo_scene.save(file_name)
            # Same title convention as on_open
            self.setWindowTitle(file_name)

    def on_new(self):
        """Clear the scene and disconnect the run slider."""
        self.run_slider.reset()
        self.ufo_scene.clear_scene()
        self.setWindowTitle('tofu flow')

    def on_open(self):
        """Ask for a flow file and load it into the scene."""
        if self.last_dirs['scene']:
            path = self.last_dirs['scene']
        else:
            path = os.path.join(xdg_data_home(), 'tofu', 'flows')
            if not os.path.exists(path):
                path = pathlib.Path.home()
        file_name, _ = QFileDialog.getOpenFileName(self,
                                                   "Open Flow Scene",
                                                   str(path),
                                                   "Flow Scene Files (*.flow)")
        if file_name:
            self.last_dirs['scene'] = os.path.dirname(file_name)
            self.ufo_scene.load(file_name)
            self.run_slider.reset()
            self.setWindowTitle(file_name)

    def on_exception_occured(self, text):
        """Show *text* in a modal error message box."""
        msg = QMessageBox(parent=self)
        msg.setIcon(QMessageBox.Critical)
        msg.setText(text)
        msg.setWindowTitle("Error")
        msg.exec_()

    def on_number_of_inputs_changed(self, value):
        self.progress_bar.setMaximum(value)

    def on_processed(self, value):
        # *value* is a zero-based index, the progress bar shows the count
        self.progress_bar.setValue(value + 1)

    def on_node_deleted(self, node):
        """
        Disconnect the run slider if *node* is (or contains) the model the slider is connected
        to, unless the node is only being removed because of a composite creation/expansion.
        """
        slider_model, prop_name = self.run_slider_key
        if slider_model:
            if (isinstance(node.model, BaseCompositeModel)
                    and node.model.is_model_inside(slider_model)
                    and not (self._expanding_composite or self._creating_composite)):
                self.run_slider.reset()
                self.run_slider_key = (None, None)
            elif node.model == slider_model and not self._creating_composite:
                self.run_slider.reset()
                self.run_slider_key = (None, None)

    def on_nodes_duplicated(self, selected_nodes, new_nodes):
        """
        Place the duplicates *new_nodes* (mapping original node -> duplicate) 100 units below the
        lowest bottom edge of *selected_nodes*, keeping their relative vertical offsets.
        """
        min_y = float('inf')
        y_1 = float('-inf')
        for node in selected_nodes:
            height = node.model.embedded_widget().height()
            y = node.graphics_object.y()
            if y < min_y:
                min_y = y
            if y + height > y_1:
                y_1 = y + height
        for node in selected_nodes:
            dy = node.graphics_object.y() - min_y
            new_pos = QPoint(int(node.graphics_object.x()), int(dy + y_1 + 100))
            new_nodes[node].graphics_object.setPos(new_pos)

    def on_item_focus_in(self, item, prop_name, caption, model):
        """Connect the run slider to the focused *item* unless it is fixed and connected."""
        if not self.fix_run_slider.isChecked() or not self.run_slider.view_item:
            if self.run_slider.setup(item):
                self.run_slider_key = (model, prop_name)
                self.run_slider.setWindowTitle(f'{caption}->{prop_name}')

    def on_selection_menu_about_to_show(self):
        """Enable the node actions which make sense for the current selection."""
        selected = self.ufo_scene.selected_nodes()
        composites = any(isinstance(node.model, BaseCompositeModel) for node in selected)
        single_composite = len(selected) == 1 and composites
        self.edit_composite_action.setEnabled(single_composite)
        self.export_composite_action.setEnabled(single_composite)
        self.expand_composite_action.setEnabled(composites)
        self.skip_action.setEnabled(selected != [])

    def on_edit_composite(self):
        if self.ufo_scene.is_selected_one_composite():
            # Check again in case this was invoked by the keyboard shortcut
            node = self.ufo_scene.selected_nodes()[0]
            node.model.edit_in_window(self)

    def on_create_composite(self):
        """
        Create a composite from the selected nodes and keep the run slider connected if its model
        ends up inside the new composite.
        """
        self._creating_composite = True
        try:
            path = None
            prop_name = self.run_slider_key[1]
            if self.run_slider_key[0]:
                # Find the path to the slider model within the selection
                for node in self.ufo_scene.selected_nodes():
                    if isinstance(node.model, BaseCompositeModel):
                        if node.model.is_model_inside(self.run_slider_key[0]):
                            path = node.model.get_path_from_model(self.run_slider_key[0])
                    elif node.model == self.run_slider_key[0]:
                        path = [self.run_slider_key[0]]
            composite_model = self.ufo_scene.create_composite().model
            if path:
                str_path = [model.caption for model in path]
                new_model = composite_model.get_model_from_path(str_path)
                new_view_item = new_model.get_view_item(prop_name)
                # Do not make complete setup, that would reset limits, just update the view item
                self.run_slider.view_item = new_view_item
                self.run_slider_key = (new_model, prop_name)
                title = '->'.join([composite_model.caption] + str_path + [prop_name])
                self.run_slider.setWindowTitle(title)
        finally:
            self._creating_composite = False

    def on_expand_composite(self):
        """
        Expand the selected composites and keep the run slider connected to the model it was
        connected to, which after the expansion lives in one of the new nodes.
        """
        self._expanding_composite = True
        try:
            slider_model, prop_name = self.run_slider_key
            for node in self.ufo_scene.selected_nodes():
                if not isinstance(node.model, BaseCompositeModel):
                    continue
                str_path = None
                if slider_model and node.model.is_model_inside(slider_model):
                    str_path = [model.caption for model in
                                node.model.get_path_from_model(slider_model)]
                new_nodes = self.ufo_scene.expand_composite(node)[0]
                if not str_path:
                    # Run slider not connected to a model inside this composite
                    continue
                # Pass the new node to the run slider because it was contained in this composite
                if slider_model.caption in new_nodes:
                    # Run slider linked to a simple node after expansion
                    slider_model = new_nodes[slider_model.caption].model
                    self.run_slider_key = (slider_model, prop_name)
                    # Do not make complete setup, that would reset limits, just update the
                    # view item
                    self.run_slider.view_item = slider_model.get_view_item(prop_name)
                    self.run_slider.setWindowTitle(f'{slider_model.caption}->{prop_name}')
                else:
                    # Run slider linked to another composite node (nesting) after expansion
                    sub_path = str_path[2:]
                    for new_node in new_nodes.values():
                        if (isinstance(new_node.model, BaseCompositeModel)
                                and new_node.model.contains_path(sub_path)):
                            new_model = new_node.model.get_model_from_path(sub_path)
                            self.run_slider_key = (new_model, prop_name)
                            # Do not make complete setup, that would reset limits, just
                            # update the view item
                            self.run_slider.view_item = new_model.get_view_item(prop_name)
                            title = '->'.join(str_path[1:] + [prop_name])
                            self.run_slider.setWindowTitle(title)
                            break
        finally:
            self._expanding_composite = False

    def on_import_composites(self):
        """
        Register the composite models stored in user-selected files and warn about names which
        were already registered (the newly loaded models are registered anyway).
        """
        if self.last_dirs['composite']:
            path = self.last_dirs['composite']
        else:
            path = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites')
            if not os.path.exists(path):
                path = pathlib.Path.home()
        file_names, _ = QFileDialog.getOpenFileNames(self,
                                                     "Select File Names",
                                                     str(path),
                                                     "Composite Model Files (*.cm)")
        if not file_names:
            return
        self.last_dirs['composite'] = os.path.dirname(file_names[0])
        overwriting = {}
        for file_name in file_names:
            LOG.debug(f'Loading composite from {file_name}')
            with open(file_name, 'r') as f:
                state = json.load(f)
            for model in get_composite_model_classes_from_json(state):
                if model.name in self.ufo_scene.registry.registered_model_creators():
                    overwriting[model.name] = os.path.basename(file_name)
                self.ufo_scene.registry.register_model(model,
                                                       category='Composite',
                                                       registry=self.ufo_scene.registry)
        if overwriting:
            msg = QMessageBox(parent=self)
            msg.setIcon(QMessageBox.Warning)
            msg.setText('Composite nodes with same names detected. Files from which '
                        'the nodes have been loaded are listed in details.')
            msg.setDetailedText('\n'.join([f'Node name "{name}" from file "{file_name}"'
                                           for (name, file_name) in overwriting.items()]))
            msg.setWindowTitle('Warning')
            msg.exec_()

    def export_composite(self, node, file_name):
        """Write the saved state of *node*'s model as JSON to *file_name*."""
        state = node.model.save()
        with open(file_name, 'w') as f:
            json.dump(state, f, indent=4)

    def on_export_composite(self):
        """Ask for a file name and export the single selected composite there."""
        if not self.ufo_scene.is_selected_one_composite():
            # Check again in case this was invoked by the keyboard shortcut
            return
        if self.last_dirs['composite']:
            path = self.last_dirs['composite']
        else:
            path = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites')
            if not os.path.exists(path):
                os.makedirs(path)
        file_name, _ = QFileDialog.getSaveFileName(self,
                                                   "Select File Name",
                                                   str(path),
                                                   "Composite Model Files (*.cm)")
        if file_name:
            self.last_dirs['composite'] = os.path.dirname(file_name)
            if not file_name.endswith('.cm'):
                file_name += '.cm'
            self.export_composite(self.ufo_scene.selected_nodes()[0], file_name)

    def on_reset_view(self):
        """Reset the transformation (zoom) of all views of the scene."""
        for view in self.ufo_scene.views():
            transform = view.transform()
            transform.reset()
            view.setTransform(transform)

    def on_property_links_action(self):
        self.property_links_widget.show()
        # Make sure it goes to the front if it is currently buried under other windows
        self.property_links_widget.raise_()

    def on_console_action(self):
        """Show the Python console (created on first use) with the scene available as `scene`."""
        if self.console:
            self.console.show()
            return
        try:
            from pyqtconsole.console import PythonConsole
            from pyqtconsole.highlighter import format as console_format
            self.console = PythonConsole(formats={
                'keyword': console_format('darkBlue', 'bold')
            })
            self.console.setWindowFlag(Qt.SubWindow, True)
            self.console.ctrl_d_exits_console(True)
            self.console.push_local_ns('scene', self.ufo_scene)
            self.console.resize(640, 480)
            self.console.show()
            self.console.eval_queued()
        except ImportError as e:
            LOG.error(e, exc_info=True)
            self.on_exception_occured(str(e))

    def on_run_slider_action(self):
        """Show the run slider or tell the user how to connect it if it is not connected."""
        if not self.run_slider.view_item:
            msg = QMessageBox(parent=self)
            msg.setIcon(QMessageBox.Information)
            msg.setText('Click on an input field in the flow to connect the slider')
            msg.exec_()
        else:
            self.run_slider.show()
            # Make sure it goes to the front if it is currently buried under other windows
            self.run_slider.raise_()

    def on_run_slider_value_changed(self, value):
        # Re-run the flow on slider changes unless an execution is already in progress
        if self.run_action.isEnabled():
            self.on_run()

    def on_run(self):
        """
        Run the scene's only graph and disable running and editing until execution finishes.

        :raises FlowError: if the scene does not consist of exactly one fully connected graph
        """
        graphs = self.ufo_scene.get_simple_node_graphs()
        if len(graphs) != 1:
            raise FlowError('Scene must contain one fully connected graph')
        if not self.ufo_scene.is_fully_connected():
            raise FlowError('Not all node ports are connected')
        self.executor.run(graphs[0])
        self.run_action.setEnabled(False)
        self.ufo_scene.set_enabled(False)

    def on_execution_finished(self):
        self.progress_bar.reset()
        self.run_action.setEnabled(True)
        self.ufo_scene.set_enabled(True)
class GlobalExceptionHandler(QObject):
    """
    Replacement for sys.excepthook which logs every exception. For FlowError (a user error) the
    message is emitted via *exception_occured*, so that e.g. a message box can be shown in the
    main thread.
    """

    exception_occured = pyqtSignal(str)

    def excepthook(self, exc_type, exc_value, exc_traceback):
        """Log the exception and forward FlowError messages to *exception_occured*."""
        LOG.error(exc_value, exc_info=(exc_type, exc_value, exc_traceback))
        if not issubclass(exc_type, FlowError):
            return
        self.exception_occured.emit(str(exc_value))
def get_filled_registry():
    """
    Create a DataModelRegistry holding the generic UFO task models, the specialized models and the
    composite models which are not already registered under the same name.

    Generic UFO models without outputs go to 'Output', those without inputs (but with outputs) to
    'Input' and all others to 'Processing'.
    """
    registry = DataModelRegistry()
    for model in get_ufo_model_classes():
        if model.num_ports['output'] == 0:
            category = 'Output'
        elif model.num_ports['input'] == 0:
            category = 'Input'
        else:
            category = 'Processing'
        registry.register_model(model, category=category, scrollable=True)

    special_models = (
        (UfoGeneralBackprojectModel, 'Processing'),
        (UfoOpenCLModel, 'Processing'),
        (UfoRetrievePhaseModel, 'Processing'),
        (UfoMemoryOutModel, 'Data'),
        (ImageViewerModel, 'Output'),
        (UfoWriteModel, 'Output'),
        (UfoReadModel, 'Input'),
    )
    for model, category in special_models:
        registry.register_model(model, category=category)

    for composite_models in get_composite_model_classes():
        for model in composite_models:
            if model.name in registry.registered_model_creators():
                continue
            registry.register_model(model, category='Composite', registry=registry)

    return registry
def main():
    """Start the tofu flow GUI and exit with the Qt event loop's return code."""
    application = QApplication(sys.argv)
    window = ApplicationWindow(UfoScene(registry=get_filled_registry()))

    # Exception interception
    handler = GlobalExceptionHandler()
    handler.exception_occured.connect(window.on_exception_occured)
    # Do not use threading.excepthook because it needs at least python 3.8., i.e. all exceptions in
    # threads have to be handled properly (logged, signal emitted so that a message can be displayed
    # in the main thread to the user, see tofu.flow.execution for example).
    sys.excepthook = handler.excepthook

    window.show()
    sys.exit(application.exec_())
if __name__ == '__main__':
    # Start the GUI when this module is executed directly
    main()
tofu-0.12.0/tofu/flow/models.py 0000664 0000000 0000000 00000172625 14237137211 0016346 0 ustar 00root root 0000000 0000000 """
All classes needed for :class:`qtpynodeeditor.NodeDataModel` implementation of UFO and
composite tasks.
"""
import gi
import glob
import json
import logging
import networkx as nx
import numpy as np
import pkg_resources
import os
import re
gi.require_version('Ufo', '0.0')
from gi.repository import Ufo
from PyQt5 import QtCore
from PyQt5.QtCore import QObject, Qt, pyqtSignal
from PyQt5.QtGui import QDoubleValidator, QValidator
from PyQt5.QtWidgets import (QCheckBox, QComboBox, QGroupBox, QInputDialog, QLabel, QLineEdit,
QScrollArea, QWidget, QFileDialog, QFormLayout, QVBoxLayout, QMenu)
from qtpynodeeditor import (NodeData, NodeDataModel, NodeDataType, FlowScene, FlowView, Port,
PortType, opposite_port)
from threading import Lock
from tofu.flow.util import (CompositeConnection, FlowError, get_config_key, MODEL_ROLE, NODE_ROLE,
PROPERTY_ROLE, saved_kwargs)
from tofu.flow.filedirdialog import FileDirDialog
# Module-level logger
LOG = logging.getLogger(__name__)
# Module-wide UFO plugin manager instance
UFO_PLUGIN_MANAGER = Ufo.PluginManager()
# Port data types of the node editor: UFO buffers and (judging by the id) numpy arrays; the scene
# does not allow connecting one to the other
UFO_DATA_TYPE = NodeDataType(id="UfoBuffer", name=None)
ARRAY_DATA_TYPE = NodeDataType(id="NumpyArray", name=None)
class UfoIntValidator(QValidator):
    """
    Combined int and unsigned int validator accepting integers in [*minimum*, *maximum*].
    """

    def __init__(self, minimum, maximum, parent=None):
        super().__init__(parent=parent)
        self.minimum = minimum
        self.maximum = maximum

    def bottom(self):
        """Return the lowest accepted value."""
        return self.minimum

    def top(self):
        """Return the highest accepted value."""
        return self.maximum

    def validate(self, input_str, pos):
        """
        Return (state, input_str, pos): Acceptable for an integer within the limits, Intermediate
        for an integer outside of them, for an empty string or for a lone "-" when negative values
        are allowed, and Invalid otherwise.
        """
        try:
            value = int(input_str)
        except ValueError:
            # Empty input or a lone minus sign may still become a valid number
            partial = not input_str or (input_str == '-' and self.minimum < 0)
            state = QValidator.Intermediate if partial else QValidator.Invalid
        else:
            within = self.minimum <= value <= self.maximum
            state = QValidator.Acceptable if within else QValidator.Intermediate
        return (state, input_str, pos)
class UfoRangeValidator(QValidator):
    """
    Range separated by comma validator. *num_items* specifies how many numbers must be in the
    string (any amount if None). *is_float* specifies if the numbers are floating point (integer
    or unsigned integer otherwise).
    """

    # Incomplete floating point number which may still become valid, e.g. "-", "1.", "1e", "1e-"
    # (matched against lowercased input); compiled once instead of on every keystroke
    _PARTIAL_FLOAT = re.compile(r'[+-]|[+-]?(\d+(\.\d*)?|\.\d*)([eE][+-]?\d*)?')

    def __init__(self, num_items=None, is_float=True, parent=None):
        super().__init__(parent=parent)
        self.num_items = num_items
        self.is_float = is_float

    def validate(self, input_str, pos):
        """
        Return (state, input_str, pos). The state is Invalid if there are more than *num_items*
        numbers or one of them cannot become a valid number, Intermediate if a number is empty or
        incomplete or there are fewer than *num_items* numbers, Acceptable otherwise.
        """
        numbers = input_str.split(',')
        intermediate = False
        if self.num_items is not None and len(numbers) > self.num_items:
            # Incorrect number of items
            return (QValidator.Invalid, input_str, pos)
        for number in numbers:
            number = number.lower().strip()
            if ('e' in number or '.' in number) and not self.is_float:
                # Integer expected
                return (QValidator.Invalid, input_str, pos)
            if self.is_float:
                try:
                    float(number)
                except ValueError:
                    if not number or self._PARTIAL_FLOAT.fullmatch(number):
                        # Partial floating point number (e.g. ends with "e")
                        intermediate = True
                        continue
                    return (QValidator.Invalid, input_str, pos)
            else:
                try:
                    int(number)
                except ValueError:
                    if not number or number == '-':
                        intermediate = True
                        continue
                    return (QValidator.Invalid, input_str, pos)
        if intermediate or (self.num_items is not None and len(numbers) < self.num_items):
            # Not enough arguments received or some numbers are incomplete
            return (QValidator.Intermediate, input_str, pos)
        return (QValidator.Acceptable, input_str, pos)
class ViewItem(QObject):
    """
    Base class wrapping an editor *widget* with typed get/set access. *property_changed* is
    emitted with the item itself when the user changes the value to a valid one.
    """

    property_changed = pyqtSignal(QObject)

    def __init__(self, widget, default_value=None, tooltip=''):
        super().__init__(parent=None)
        self.widget = widget
        # True for items which provide a focus_in signal (see QLineEditViewItem)
        self.focus_info = False
        if tooltip:
            self.widget.setToolTip(tooltip)
        if default_value is not None:
            self.set(default_value)

    def on_changed(self, *args):
        """
        Only user interaction must emit signals in the descendants. Signal is emitted only if the
        user input is valid, i.e. if get() does not raise.
        """
        try:
            self.get()
        except Exception:
            LOG.debug(f'{self}: invalid input')
        else:
            self.property_changed.emit(self)

    def get(self):
        """Return the current value; to be implemented by descendants."""
        ...

    def set(self, value):
        """Set *value* in the widget; to be implemented by descendants."""
        ...
class CheckBoxViewItem(ViewItem):
    """View item holding a boolean in a QCheckBox."""

    def __init__(self, checked=False, tooltip=''):
        check_box = QCheckBox()
        super().__init__(check_box, default_value=checked, tooltip=tooltip)
        # `clicked` is emitted on user interaction, not by setChecked()
        check_box.clicked.connect(self.on_changed)

    def get(self):
        return self.widget.isChecked()

    def set(self, value):
        self.widget.setChecked(value)
class ComboBoxViewItem(ViewItem):
    """View item holding one of the strings *items* in a QComboBox."""

    def __init__(self, items, default_value=None, tooltip=''):
        combo_box = QComboBox()
        for text in items:
            combo_box.addItem(text)
        super().__init__(combo_box, default_value=default_value, tooltip=tooltip)
        # `activated` is emitted when the user chooses an item
        combo_box.activated.connect(self.on_changed)

    def get(self):
        return self.widget.currentText()

    def set(self, value):
        self.widget.setCurrentText(value)
class FocusInterceptQLineEdit(QLineEdit):
    """QLineEdit which emits *focus_in* with itself whenever it receives focus."""

    focus_in = pyqtSignal(QObject)

    def focusInEvent(self, event):
        self.focus_in.emit(self)
        # Continue with the default QLineEdit focus handling
        return super().focusInEvent(event)
class QLineEditViewItem(ViewItem):
    """
    View item holding text in a line edit. With *intercept_focus* the item emits *focus_in* with
    itself when the line edit gets focus.
    """

    focus_in = pyqtSignal(QObject)

    def __init__(self, default_value=None, tooltip='', intercept_focus=False):
        line_edit = FocusInterceptQLineEdit() if intercept_focus else QLineEdit()
        super().__init__(line_edit, default_value=default_value, tooltip=tooltip)
        if intercept_focus:
            line_edit.focus_in.connect(self.on_focus_in)
            # Advertise that this item provides focus_in
            self.focus_info = True
        # `textEdited` is emitted on user edits, not by setText()
        line_edit.textEdited.connect(self.on_changed)

    def on_focus_in(self, widget):
        self.focus_in.emit(self)

    def get(self):
        return self.widget.text()

    def set(self, value):
        self.widget.setText(str(value))
class NumberQLineEditViewItem(QLineEditViewItem):
    """
    Line edit view item for floating point numbers in [*minimum*, *maximum*]. The range is
    appended to *tooltip*.

    :raises ValueError: if *default_value* is given and lies outside of the limits
    """

    def __init__(self, minimum, maximum, default_value=None, tooltip=''):
        # default_value=None (the signature default) means "no initial value"; comparing None
        # with the limits would raise TypeError
        if default_value is not None and (default_value < minimum or default_value > maximum):
            raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]')
        tooltip += f' (range: {minimum} - {maximum})'
        super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True)
        validator = QDoubleValidator(float(minimum), float(maximum), 100)
        self.widget.setValidator(validator)

    def get(self):
        return float(super().get())
class IntQLineEditViewItem(QLineEditViewItem):
    """
    Line edit view item for integers in [*minimum*, *maximum*]. The range is appended to
    *tooltip*.

    :raises ValueError: if *default_value* is given and lies outside of the limits
    """

    def __init__(self, minimum, maximum, default_value=None, tooltip=''):
        # default_value=None (the signature default) means "no initial value"; comparing None
        # with the limits would raise TypeError
        if default_value is not None and (default_value < minimum or default_value > maximum):
            raise ValueError(f'default value {default_value} not in limits [{minimum}, {maximum}]')
        tooltip += f' (range: {minimum} - {maximum})'
        super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True)
        validator = UfoIntValidator(minimum, maximum)
        self.widget.setValidator(validator)

    def get(self):
        return int(super().get())
class RangeQLineEditViewItem(QLineEditViewItem):
    """
    Line edit view item for a comma-separated list of numbers, validated by UfoRangeValidator.
    get() returns the numbers as a list of floats (empty list for empty text).
    """

    def __init__(self, default_value='', tooltip='', num_items=None, is_float=True):
        super().__init__(default_value=default_value, tooltip=tooltip, intercept_focus=True)
        self.widget.setValidator(UfoRangeValidator(num_items=num_items, is_float=is_float))

    def set(self, values):
        self.widget.setText(','.join(map(str, values)) if values else '')

    def get(self):
        text = super().get()
        return [float(num) for num in text.split(',')] if text else []
def get_ufo_qline_edit_item(glib_prop, default_value, range_num_items=None, range_is_float=True):
    """
    Create the line edit view item matching the GLib property *glib_prop*: a range item for
    GValueArray, a float item for gdouble/gfloat, an int item for other properties with limits
    and a plain text item otherwise. The property's blurb becomes the tooltip.
    """
    type_name = glib_prop.value_type.name
    tooltip = glib_prop.blurb
    if type_name == 'GValueArray':
        return RangeQLineEditViewItem(tooltip=tooltip, default_value=default_value,
                                      num_items=range_num_items, is_float=range_is_float)
    if type_name in ['gdouble', 'gfloat']:
        return NumberQLineEditViewItem(glib_prop.minimum, glib_prop.maximum,
                                       default_value=default_value, tooltip=tooltip)
    if hasattr(glib_prop, 'minimum') and hasattr(glib_prop, 'maximum'):
        return IntQLineEditViewItem(glib_prop.minimum, glib_prop.maximum,
                                    default_value=default_value, tooltip=tooltip)
    return QLineEditViewItem(default_value=str(default_value), tooltip=tooltip)
class PropertyViewRecord:
    """Attribute-access to a view's item: the view item, its form label and its visibility."""

    def __init__(self, view_item, label, visible):
        self.view_item = view_item
        self.label = label
        self.visible = visible

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return f'PropertyViewRecord(widget={self.view_item.widget}, visible={self.visible})'
class MultiPropertyViewRecord:
    """Attribute-access to a multiple property view's item: model, group widget, visibility."""

    def __init__(self, model, widget, visible):
        self.model = model
        self.widget = widget
        self.visible = visible

    def __str__(self):
        return repr(self)

    def __repr__(self):
        return (f'MultiPropertyViewRecord(model={self.model}, widget={self.widget}, '
                f'visible={self.visible})')
class PropertyView(QWidget):
    """
    Form with one row (label, editor widget) per property.

    *property_changed* is emitted with the property name and the new value when the user changes
    a property; *item_focus_in* is emitted with the view item and the property name when an item
    providing focus information gets focus. The context menu toggles row visibility.
    """

    property_changed = pyqtSignal(str, object)
    item_focus_in = pyqtSignal(ViewItem, str)

    def __init__(self, properties=None, parent=None, scrollable=True):
        """
        :param properties: dictionary {name: (view_item, visible)}
        :param parent: parent widget
        :param scrollable: put the form into a QScrollArea
        :raises ValueError: if a property name occurs more than once
        """
        super().__init__(parent=parent)
        form_layout = QFormLayout()
        form_layout.setVerticalSpacing(0)
        self._properties = {}
        if properties:
            for (name, (item, active)) in properties.items():
                if name in self._properties:
                    raise ValueError("Item '{}' already exists".format(name))
                # Set the parent properly, so that set_property_visible won't try to show the item
                # widget and the label in their own windows before the view is shown
                item.widget.setParent(self)
                label = QLabel(name, parent=self)
                form_layout.addRow(label, item.widget)
                self._properties[name] = PropertyViewRecord(item, label, active)
                self.set_property_visible(name, active)
                item.property_changed.connect(self.on_property_changed)
                if item.focus_info:
                    # focus_info is set by items which emit focus_in
                    item.focus_in.connect(self.on_item_focus_in)
        if scrollable:
            widget = QWidget()
            widget.setLayout(form_layout)
            scroll = QScrollArea()
            scroll.setWidget(widget)
            scroll.setWidgetResizable(True)
            main_layout = QVBoxLayout()
            main_layout.addWidget(scroll)
            self.setLayout(main_layout)
        else:
            self.setLayout(form_layout)

    @property
    def property_names(self):
        return self._properties.keys()

    def get_property(self, name):
        return self._properties[name].view_item.get()

    def set_property(self, name, value):
        return self._properties[name].view_item.set(value)

    def get_record(self, name):
        return self._properties[name]

    def on_property_changed(self, item):
        """Re-emit a change of *item* as property_changed(name, value)."""
        for (name, record) in self._properties.items():
            if item == record.view_item:
                self.property_changed.emit(name, item.get())
                return
        # Previously the last iterated name was emitted here (or UnboundLocalError raised)
        LOG.warning(f'{self}: changed item {item} does not belong to this view')

    def on_item_focus_in(self, view_item):
        """Re-emit focus of *view_item* as item_focus_in(view_item, name)."""
        for (name, it) in self._properties.items():
            if it.view_item.widget == view_item.widget:
                self.item_focus_in.emit(view_item, name)
                break

    def is_property_visible(self, name):
        return self._properties[name].visible

    def set_property_visible(self, name, visible):
        """Show or hide both the editor widget and the label of property *name*."""
        self._properties[name].view_item.widget.setVisible(visible)
        self._properties[name].label.setVisible(visible)
        self._properties[name].visible = visible

    def restore_properties(self, values):
        """Restore values and visibility from {name: [value, visible]}; missing keep defaults."""
        for prop in self._properties:
            if prop not in values:
                LOG.debug(f'Property {prop} not stored, using default')
                continue
            value, visible = values[prop]
            self.set_property(prop, value)
            self.set_property_visible(prop, visible)

    def export_properties(self):
        """Return {name: [value, visible]} for all properties."""
        values = {}
        for prop in self._properties:
            values[prop] = [self.get_property(prop), self.is_property_visible(prop)]
        return values

    def contextMenuEvent(self, event):
        """Offer toggling of every property's visibility plus "Show All" and "Hide All"."""
        contextMenu = QMenu(self)
        actions = {}
        for name in list(self._properties.keys()):
            action = contextMenu.addAction(name)
            action.setCheckable(True)
            action.setChecked(self._properties[name].visible)
            actions[action] = name
        contextMenu.addSeparator()
        show_all_action = contextMenu.addAction('Show All')
        hide_all_action = contextMenu.addAction('Hide All')
        action = contextMenu.exec_(self.mapToGlobal(event.pos()))
        if action:
            if action in actions:
                name = actions[action]
                checked = action.isChecked()
                self.set_property_visible(name, checked)
            elif action == show_all_action:
                for name in self._properties.keys():
                    self.set_property_visible(name, True)
            elif action == hide_all_action:
                for name in self._properties.keys():
                    self.set_property_visible(name, False)
class MultiPropertyView(QWidget):
    """
    Scrollable view with one group per model from *groups* (a ``{model: visible}`` dictionary).

    A PropertyModel is shown as a group box containing its embedded widget. Any other model is
    shown only as a label with its caption. Groups are keyed by model caption.
    """

    def __init__(self, groups, parent=None):
        super().__init__(parent=parent)
        self._group_box_layout = QVBoxLayout()
        main_layout = QVBoxLayout()
        widget = QWidget()
        widget.setLayout(self._group_box_layout)
        scroll = QScrollArea()
        scroll.setWidget(widget)
        scroll.setWidgetResizable(True)
        self.setLayout(main_layout)
        main_layout.addWidget(scroll)
        self._groups = {}

        for (model, visible) in groups.items():
            if isinstance(model, PropertyModel):
                model_widget = QGroupBox(model.caption)
                layout = QVBoxLayout()
                model_widget.setLayout(layout)
                embedded = model.embedded_widget()
                # Property models without properties have no embedded widget and Qt layouts
                # refuse None
                if embedded is not None:
                    layout.addWidget(embedded)
            else:
                model_widget = QLabel(model.caption, parent=self)
            record = MultiPropertyViewRecord(model, model_widget, visible)
            self._groups[model.caption] = record
            self._group_box_layout.addWidget(model_widget)
            self.set_group_visible(model.caption, visible)

    def __getitem__(self, key):
        """Return the model whose caption is *key*."""
        return self._groups[key].model

    def __contains__(self, key):
        return key in self._groups

    def __iter__(self):
        """Iterate over model captions."""
        return iter(self._groups)

    def export_groups(self):
        """Return ``{caption: {'model': saved model state, 'visible': bool}}``."""
        values = {}
        for name in self._groups:
            state = self._groups[name].model.save()
            values[name] = {'model': state,
                            'visible': self._groups[name].visible}
        return values

    def restore_groups(self, values):
        """Restore model states and group visibility from the output of export_groups."""
        for name in values:
            self[name].restore(values[name]['model'])
            self.set_group_visible(name, values[name]['visible'])

    def set_group_visible(self, name, visible):
        """Show or hide the group of model *name* and remember the state."""
        self._groups[name].widget.setVisible(visible)
        self._groups[name].visible = visible

    def is_group_visible(self, name):
        return self._groups[name].visible

    def contextMenuEvent(self, event):
        """Offer a menu toggling the visibility of single groups or of all of them at once."""
        contextMenu = QMenu(self)
        actions = {}
        for name in list(self._groups.keys()):
            action = contextMenu.addAction(name)
            action.setCheckable(True)
            action.setChecked(self._groups[name].visible)
            actions[action] = name
        contextMenu.addSeparator()
        show_all_action = contextMenu.addAction('Show All')
        hide_all_action = contextMenu.addAction('Hide All')
        action = contextMenu.exec_(self.mapToGlobal(event.pos()))
        if action:
            if action in actions:
                name = actions[action]
                # The checkable action has already been toggled by the click
                checked = action.isChecked()
                self.set_group_visible(name, checked)
            elif action == show_all_action:
                for name in self._groups.keys():
                    self.set_group_visible(name, True)
            elif action == hide_all_action:
                for name in self._groups.keys():
                    self.set_group_visible(name, False)
class UfoModel(NodeDataModel):
    """The root parent of all other models in tofu flow."""

    data_type = UFO_DATA_TYPE
    item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel)

    def __init__(self, style=None, parent=None):
        super().__init__(style=style, parent=parent)
        # The caption the model wants when it is instantiated. The scene may assign a different
        # one because captions must be unique within a scene.
        self.base_caption = self.caption
        self.skip = False

    def restore(self, state, restore_caption=False):
        """Restore from *state*; the caption is only taken over if *restore_caption* is True."""
        if not restore_caption:
            return
        self.caption = state.get('caption', self.caption)

    def save(self):
        """Return the serializable state, at this level only the caption."""
        return dict(caption=self.caption)

    def double_clicked(self, parent):
        """React to a double click on the node; does nothing by default."""

    def __repr__(self):
        return f'UfoModel({self.caption})'

    def __str__(self):
        return repr(self)
class PropertyModel(UfoModel):
    """
    Model whose settings are properties edited in an embedded PropertyView.

    If make_properties returns no properties, no view is created. Such a model has an empty
    property list and no embedded widget.
    """

    property_changed = pyqtSignal(UfoModel, str, object)

    def __init__(self, style=None, parent=None, scrollable=True):
        """Create the property view from make_properties(); *scrollable* is passed to the view."""
        super().__init__(style=style, parent=parent)
        properties = self.make_properties()
        if properties:
            self.properties = list(properties.keys())
            self._view = PropertyView(properties=properties, scrollable=scrollable)
            self._view.property_changed.connect(self.on_property_changed)
            self._view.item_focus_in.connect(self.on_item_focus_in)
        else:
            self.properties = []
            self._view = None

    def __getitem__(self, key):
        return self._view.get_property(key)

    def __setitem__(self, key, value):
        return self._view.set_property(key, value)

    def __contains__(self, key):
        return key in self.properties

    def __iter__(self):
        return iter(self.properties)

    def get_view_item(self, name):
        return self._view.get_record(name).view_item

    def on_property_changed(self, name, value):
        self.property_changed.emit(self, name, value)

    def on_item_focus_in(self, item, name):
        self.item_focus_in.emit(item, name, self.caption, self)

    def make_properties(self):
        """Return a dictionary of name: [ViewItem, visible] items (none by default)."""
        return {}

    def copy_properties(self):
        """Return freshly made properties holding the current values and visibility."""
        properties = self.make_properties()
        for (name, (item, _)) in properties.items():
            item.set(self[name])
            properties[name][-1] = self._view.is_property_visible(name)
        return properties

    def auto_fill(self):
        """Automatically fill properties (e.g. number of files, etc.)"""
        ...

    def resizable(self):
        return True

    def embedded_widget(self) -> QWidget:
        return self._view if self._view else None

    def restore(self, state, restore_caption=True):
        """Restore property values and visibility from *state* (see save) and the caption."""
        # Models without properties have no view to restore
        if self._view is not None:
            self._view.restore_properties(state.get('properties', {}))
        super().restore(state, restore_caption=restore_caption)

    def save(self):
        """Return the caption plus ``'properties': {name: [value, visible]}``."""
        state = super().save()
        state['properties'] = self._view.export_properties() if self._view is not None else {}
        return state
class UfoTaskModel(PropertyModel):
    """Model wrapping the UFO task *task_name*; the task's properties become model properties."""

    caption_visible = True

    def __init__(self, task_name, style=None, parent=None, scrollable=True):
        self._task_name = task_name
        # Turn e.g. 'memory_out' into 'Memory Out'. Slicing instead of indexing keeps empty
        # segments (consecutive underscores) from raising IndexError.
        self.caption = ' '.join(part[:1].upper() + part[1:] for part in self.name.split('_'))
        self.needs_fixed_scheduler = False
        self.can_split_gpu_work = False
        # Must come last, the base class calls make_properties which needs the task name
        super().__init__(style=style, parent=parent, scrollable=scrollable)

    def make_properties(self):
        """Create a view item for every property of the UFO task (except 'num-processed')."""
        hidden_properties = get_config_key('models', self._task_name, 'hidden-properties')
        range_properties = get_config_key('models', self._task_name, 'range-properties', default={})
        properties = {}
        ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name)
        for prop in ufo_task.list_properties():
            if prop.name == 'num-processed':
                continue
            default_value = getattr(ufo_task.props, prop.name)
            if prop.value_type.name == 'gboolean':
                item = CheckBoxViewItem(checked=default_value, tooltip=prop.blurb)
            elif hasattr(prop, 'enum_class'):
                items = [name.value_nick for name in default_value.__enum_values__.values()]
                item = ComboBoxViewItem(items, default_value=default_value.value_nick,
                                        tooltip=prop.blurb)
            else:
                range_num_items, range_is_float = range_properties.get(prop.name, (None, True))
                item = get_ufo_qline_edit_item(prop, default_value=default_value,
                                               range_num_items=range_num_items,
                                               range_is_float=range_is_float)
            visible = True
            if hidden_properties and prop.name in hidden_properties:
                visible = False
            properties[prop.name] = [item, visible]
        return properties

    def create_ufo_task(self, region=None):
        """Create a new UFO task instance set up with this model's properties."""
        if self.expects_multiple_inputs and region is None:
            raise UfoModelError(f'{self.caption} expects multiple inputs '
                                'but there is no node with such capability in the flow')
        ufo_task = UFO_PLUGIN_MANAGER.get_task(self._task_name)
        self._setup_ufo_task(ufo_task, region=region)
        return ufo_task

    def _setup_ufo_task(self, ufo_task, region=None):
        """Copy all model properties to *ufo_task*."""
        for prop in self:
            setattr(ufo_task.props, prop, self[prop])

    def reset_batches(self):
        """
        In case the model can process batches and has internal state depending on them, this is
        where it can be re-set.
        """
        pass

    @property
    def uses_gpu(self):
        return UFO_PLUGIN_MANAGER.get_task(self._task_name).uses_gpu()

    @property
    def expects_multiple_inputs(self):
        return False
def get_ufo_model_class(ufo_task_name):
    """
    Return a UfoTaskModel subclass for the UFO task *ufo_task_name*.

    The number of input ports comes from the task. There is one output port unless the task is
    a sink. Port captions are taken from the 'port-captions' configuration key if present.
    """
    # Use this to determine inputs and outputs but create a new object in the constructor in order
    # to enable multiple instances having different parameter values
    _ufo_task = UFO_PLUGIN_MANAGER.get_task(ufo_task_name)
    ufo_task_num_inputs = _ufo_task.get_num_inputs()
    ufo_task_num_outputs = int((_ufo_task.get_mode() & Ufo.TaskMode.SINK) == 0)

    class UfoAutoModel(UfoTaskModel):
        name = ufo_task_name.replace('-', '_')

        def __init__(self, style=None, parent=None, scrollable=True):
            self.num_ports = {PortType.input: ufo_task_num_inputs,
                              PortType.output: ufo_task_num_outputs}
            self.data_type = {}
            self.port_caption = {}
            self.port_caption_visible = {}
            # The configured captions are the same for every port, look them up only once
            port_captions = get_config_key('models', ufo_task_name, 'port-captions')
            for port_type in (PortType.input, PortType.output):
                self.data_type[port_type] = {}
                self.port_caption[port_type] = {}
                self.port_caption_visible[port_type] = {}
                for i in range(self.num_ports[port_type]):
                    if port_captions:
                        port_caption = port_captions[port_type][str(i)]
                        port_caption_visible = bool(port_caption)
                    else:
                        port_caption = ''
                        port_caption_visible = False
                    self.data_type[port_type][i] = UFO_DATA_TYPE
                    self.port_caption[port_type][i] = port_caption
                    self.port_caption_visible[port_type][i] = port_caption_visible
            self.ufo_task = None
            super().__init__(ufo_task_name, style=style, parent=parent, scrollable=scrollable)

    return UfoAutoModel
class BaseCompositeModel(UfoModel):
    """
    Model made of several internally connected sub-models.

    Ports of the sub-models which are not used by the internal connections become the ports of
    this model. The sub-models' properties are shown in an embedded MultiPropertyView. The
    sub-models can be edited in a separate window (see edit_in_window).
    """

    # TODO: move functionality from CompositeModel here where possible
    data_type = UFO_DATA_TYPE

    def __init__(self, models, connections, links=None, registry=None, style=None, parent=None):
        """
        *models* is an iterable of (name, state, visible, position) tuples describing the
        sub-models. *connections* are the internal connections. *links* are the property links
        later restored by restore_links. *registry* is the model registry used to create the
        sub-models.
        """
        if registry is None:
            # This has to be keyword argument because of the qtpynodeeditor's node creation
            # mechanism, but the argument is actually required
            raise AttributeError('registry must be provided')
        super().__init__(style=style, parent=parent)
        # Composite whose edit window contains this model (set in edit_in_window)
        self.window_parent = None
        self._property_links_model = None
        self._links = [] if links is None else links
        self._slave_property_links = []
        # Nodes in the edit pop-up window
        self._window_nodes = {}
        self._other_scene = None
        self._other_view = None
        self.num_ports = {PortType.input: 0,
                          PortType.output: 0}
        self.data_type = {PortType.input: {}, PortType.output: {}}
        self.port_caption = {PortType.input: {}, PortType.output: {}}
        self.port_caption_visible = {PortType.input: {}, PortType.output: {}}
        groups = {}
        self._registry = registry
        self._models = {}
        # Internal connections
        self._connections = connections
        # Composite port to subnode port mapping
        self._inside_ports = {}
        # Subnode port to composite port mapping
        self._outside_ports = {}

        for (name, state, visible, position) in models:
            # Don't use the default registry creation because embedded PropertyModel must have
            # scrollable set to False
            cls, orig_kwargs = registry.registered_model_creators()[name]
            # Don't mess with the original dictionary
            kwargs = dict(orig_kwargs)
            if issubclass(cls, PropertyModel):
                kwargs['scrollable'] = False
            if 'num-inputs' in state:
                kwargs['num_inputs'] = state['num-inputs']
            model = cls(**kwargs)
            model.restore(state)
            self._models[model] = position
            groups[model] = visible
            model.item_focus_in.connect(self.on_item_focus_in)
            for port_type in ['input', 'output']:
                for index in range(model.num_ports[port_type]):
                    side = (model.caption, port_type, index)
                    if any(conn.contains(*side) for conn in connections):
                        # Used by an internal connection, not exposed by the composite
                        continue
                    i = self.num_ports[port_type]
                    self.data_type[port_type][i] = UFO_DATA_TYPE
                    port_caption = model.caption
                    if model.port_caption[port_type][index]:
                        port_caption += ':' + model.port_caption[port_type][index]
                    self.port_caption[port_type][i] = port_caption
                    self.port_caption_visible[port_type][i] = True
                    self._inside_ports[(port_type, i)] = (model, port_type, index)
                    self._outside_ports[side] = (port_type, i)
                    self.num_ports[port_type] += 1

        self._view = MultiPropertyView(groups)

    def __getitem__(self, key):
        """Return the sub-model with caption *key*."""
        return self._view[key]

    def __contains__(self, key):
        return key in self._view

    def __iter__(self):
        """Iterate over sub-model captions."""
        return iter(self._view)

    def __repr__(self):
        return f'Composite(caption={self.caption}, models={sorted(self._view)})'

    def __str__(self):
        return repr(self)

    def get_outside_port(self, unique_name, port_type, port_index):
        """Map a sub-model port to the composite's (port_type, index)."""
        return self._outside_ports[(unique_name, port_type, port_index)]

    def get_model_and_port_index(self, port_type, port_index):
        """Map a composite port to (sub-model, sub-model port index)."""
        model, _, index = self._inside_ports[(port_type, port_index)]
        return (model, index)

    def embedded_widget(self) -> QWidget:
        return self._view if self._view else None

    def resizable(self):
        return True

    def on_item_focus_in(self, item, name, caption, model):
        # Prepend our caption so that the receiver sees the full caption path
        self.item_focus_in.emit(item, name, self.caption + '->' + caption, model)

    @property
    def is_editing(self):
        """True if the edit subwindow is open."""
        return self._window_nodes != {}

    @property
    def property_links_model(self):
        return self._property_links_model

    @property_links_model.setter
    def property_links_model(self, plm):
        # Nested composites share the same property links model
        self._property_links_model = plm
        for model in self._models:
            if isinstance(model, BaseCompositeModel):
                model.property_links_model = plm

    def contains_path(self, path):
        """Is there a caption *path* inside this model."""
        model = self
        for caption in path:
            if caption in model:
                model = model[caption]
            else:
                return False
        return True

    def get_model_from_path(self, path):
        """*path* is caption path (str)."""
        model = self
        for caption in path:
            model = model[caption]
        return model

    def is_model_inside(self, model):
        """Return True if *model* is inside at any level."""
        paths = self.get_leaf_paths()
        for path in paths:
            for item in path:
                if item == model:
                    return True
        return False

    def get_path_from_model(self, model):
        """*model* must be inside this composite model, otherwise KeyError is raised."""
        paths = self.get_leaf_paths()
        for path in paths:
            for (i, item) in enumerate(path):
                if item == model:
                    return path[:i + 1]
        raise KeyError(f'{model} not inside')

    def get_descendant_graph(self, in_subwindow=False):
        """
        Get all descendant models recursively in case there are composite models inside this model.

        If *in_subwindow* is True, return the models shown to the user in the subwindow,
        otherwise the ones created at class instantiation. ValueError is raised if
        *in_subwindow* is True but this model is not being edited. Nested composites contribute
        their subwindow models when they are being edited and their internal models otherwise;
        they never raise.
        """
        if in_subwindow and not self.is_editing:
            raise ValueError('in_subwindow True but no subwindow open')

        graph = nx.DiGraph()

        def descend(parent):
            if in_subwindow and parent.is_editing:
                models = [node.model for node in parent._window_nodes.values()]
            else:
                models = [parent[key] for key in parent]
            for model in models:
                graph.add_edge(parent, model)
                if isinstance(model, BaseCompositeModel):
                    descend(model)

        descend(self)

        return graph

    def get_leaf_paths(self, in_subwindow=False):
        """
        Return one caption-ordered model path [self, ..., leaf] per leaf of the descendant graph.

        A leaf is a model without children. See get_descendant_graph for *in_subwindow*.
        """
        graph = self.get_descendant_graph(in_subwindow=in_subwindow)
        leaves = [node for node in graph.nodes if graph.out_degree(node) == 0]
        # Only the first path is needed, don't materialize all of them
        return [next(nx.simple_paths.all_simple_paths(graph, self, leaf)) for leaf in leaves]

    def restore(self, state, restore_caption=True):
        """Restore internal connections, sub-model states and visibility, and the caption."""
        self._connections = [CompositeConnection(*args) for args in state['connections']]
        self._view.restore_groups(state['models'])
        super().restore(state, restore_caption=restore_caption)

    def restore_links(self, node):
        """
        Add the stored property links of the sub-models to the property links model.

        *node* is used as the NODE_ROLE item of the links (presumably the scene node holding
        this model; verify against callers).
        """
        if self.property_links_model:
            row = self.property_links_model.rowCount()
            for items in self._links:
                # A row can be restored only if no property from the state is in the link model
                # yet
                row_ok = True
                for str_path in items:
                    prop_name = str_path[-1]
                    model = self.get_model_from_path(str_path[:-1])
                    if self.property_links_model.find_items([node, model, prop_name],
                                                            [NODE_ROLE, MODEL_ROLE, PROPERTY_ROLE]):
                        LOG.info(f'{str_path[-2]}->{prop_name} already in property links')
                        row_ok = False
                        break
                if row_ok:
                    for (i, str_path) in enumerate(items):
                        model = self.get_model_from_path(str_path[:-1])
                        self.property_links_model.add_item(node, model, str_path[-1], row, i)
                    row += 1

    def save(self):
        """Return name, caption, sub-model states (with positions), connections and links."""
        state = {'name': self.name, 'caption': self.caption}
        state['models'] = self._view.export_groups()
        for (model, position) in self._models.items():
            state['models'][model.caption]['position'] = position
            # This is necessary for creating models from saved files
            state['models'][model.caption]['name'] = model.name
        state['connections'] = [conn.save() for conn in self._connections]
        if self.property_links_model:
            state['links'] = []
            paths = self.get_leaf_paths()
            models = [path[-1] for path in paths]
            items = self.property_links_model.get_model_links(models)
            for row in items.values():
                # First item in the row is this model, skip it
                state['links'].append([str_path[1:] for str_path in row])

        return state

    def on_connection_created(self, connection):
        """Undo a connection created in the edit window, the composite's ports are fixed."""
        # Don't let our own deletion trigger on_connection_deleted
        self._other_scene.connection_deleted.disconnect(self.on_connection_deleted)
        try:
            self._other_scene.delete_connection(connection)
        finally:
            # Reconnect even on failure, otherwise deletions would stop being undone
            self._other_scene.connection_deleted.connect(self.on_connection_deleted)

    def on_connection_deleted(self, connection):
        """Undo a connection deletion in the edit window, the composite's ports are fixed."""
        # Don't let our own restoration trigger on_connection_created
        self._other_scene.connection_created.disconnect(self.on_connection_created)
        try:
            self._other_scene.restore_connection(connection.__getstate__())
        finally:
            # Reconnect even on failure, otherwise creations would stop being undone
            self._other_scene.connection_created.connect(self.on_connection_created)

    def double_clicked(self, parent):
        self.edit_in_window(parent=parent)

    def on_other_scene_double_clicked(self, node):
        node.model.double_clicked(self._other_view)

    def expand_into_graph(self, graph):
        """Expand to submodels in a *graph*, which is a networkx.DiGraph instance."""
        name_to_model = {}
        for model in self._models:
            LOG.debug(f'Adding node {model.name}')
            graph.add_node(model)
            name_to_model[model.caption] = model

        for conn in self._connections:
            source = name_to_model[conn.from_unique_name]
            dest = name_to_model[conn.to_unique_name]
            LOG.debug(f'Adding edge {source.name}@{conn.from_port_index} -> '
                      f'{dest.name}@{conn.to_port_index}')
            graph.add_edge(source, dest, input=conn.to_port_index, output=conn.from_port_index)

    def _expand_into_scene(self, scene, original_nodes=None, restore_captions=False):
        """Create nodes for all sub-models and the internal connections in *scene*."""
        # unique name to node instance mapping
        name_to_node = {}

        for model in self._models:
            if original_nodes and model.caption in original_nodes:
                node = scene.restore_node(original_nodes[model.caption])
            else:
                with saved_kwargs(scene.registry, model.__getstate__()):
                    if restore_captions:
                        node = scene.create_node(model.__class__)
                    else:
                        # This is the main scene, links restoration takes place in expand_into_scene
                        # for all nodes including composites
                        node = scene.create_node(model.__class__, restore_links=False)
                if isinstance(model, (PropertyModel, BaseCompositeModel)):
                    node.model.restore(model.save(), restore_caption=restore_captions)
                    if isinstance(node.model, BaseCompositeModel):
                        node.model.property_links_model = self.property_links_model
                else:
                    node.model.restore(model.save())
            name_to_node[model.caption] = node
            if self._models[model] is not None:
                node.position = (self._models[model]['x'], self._models[model]['y'])

        for conn in self._connections:
            f_node = name_to_node[conn.from_unique_name]
            t_node = name_to_node[conn.to_unique_name]
            f_port = f_node[PortType.output][conn.from_port_index]
            t_port = t_node[PortType.input][conn.to_port_index]
            scene.create_connection(f_port, t_port, check_cycles=False)

        return name_to_node

    def add_slave_links(self):
        """Link the models shown in (nested) edit windows to the root models of property links."""
        self._slave_property_links = []
        if not self.property_links_model:
            return

        for node in self._window_nodes.values():
            if isinstance(node.model, BaseCompositeModel):
                paths = node.model.get_leaf_paths(in_subwindow=node.model._window_nodes != {})
            else:
                paths = [[node.model]]
            # Propagate all signals from leaves to the original models
            for path in paths:
                str_path = [m.caption for m in path]
                new_model = path[-1]
                orig_model = self.get_model_from_path(str_path)
                # Create a link from this node's model instances to the original root
                # models in the link model (there can be other composites along the way to
                # the root)
                root_model = self.property_links_model.get_root_model(orig_model)
                if root_model:
                    prop_names = self.property_links_model.get_model_properties(root_model)
                    for prop_name in prop_names:
                        if (new_model, prop_name) not in self._slave_property_links:
                            # In order to remove slaves when the subwindow is closed, register
                            # the slaves with respect to the most nested composite
                            registering_model = path[-2] if len(path) > 1 else self
                            if registering_model.is_editing:
                                registering_model._slave_property_links.append((new_model,
                                                                                prop_name))
                                registering_model.property_links_model.add_silent(new_model,
                                                                                  prop_name,
                                                                                  root_model,
                                                                                  prop_name)
                            if registering_model.window_parent:
                                # If the registering model has a parent, register also the
                                # models in its internal model view
                                new_model = registering_model[path[-1].caption]
                                registering_model = registering_model.window_parent
                                registering_model._slave_property_links.append((new_model,
                                                                                prop_name))
                                registering_model.property_links_model.add_silent(new_model,
                                                                                  prop_name,
                                                                                  root_model,
                                                                                  prop_name)

    def edit_in_window(self, parent=None):
        """Open a window with a scene containing the sub-models for editing."""
        self._other_scene = FlowScene(registry=self._registry)
        self._other_scene.node_double_clicked.connect(self.on_other_scene_double_clicked)
        self._window_nodes = self._expand_into_scene(self._other_scene, restore_captions=True)
        # Store references to parent composites
        for node in self._window_nodes.values():
            if isinstance(node.model, BaseCompositeModel):
                node.model.window_parent = self
        # Property links have to be registered with respect to the top composite model because
        # its property model is the one registered in property links
        window_parent = self
        while window_parent.window_parent:
            window_parent = window_parent.window_parent
        window_parent.add_slave_links()

        # Disable manipulation because the number of ports is fixed, so we can't e.g. internally
        # connect two nodes and delete the newly occupied port from the composite node
        self._other_scene.allow_node_creation = False
        self._other_scene.allow_node_deletion = False
        # There is no allow_connection_creation/deletion, so take care of it here
        self._other_scene.connection_created.connect(self.on_connection_created)
        self._other_scene.connection_deleted.connect(self.on_connection_deleted)
        self._other_view = FlowView(self._other_scene, parent=parent)
        self._other_view.setWindowFlag(Qt.Window, True)
        self._other_view.closeEvent = self.view_close_event
        self._other_view.setWindowTitle(self.name)
        self._other_view.resize(900, 600)
        self._other_view.show()

    def view_close_event(self, event):
        """Copy the edited states back to the sub-models and tear the edit window down."""
        for node in self._window_nodes.values():
            # Close all composite children recursively first
            if isinstance(node.model, BaseCompositeModel) and node.model.is_editing:
                node.model._other_view.close()
                node.model.window_parent = None

        for (unique_name, node) in self._window_nodes.items():
            self._view[unique_name].restore(node.model.save())

        if self.property_links_model:
            for (model, prop_name) in self._slave_property_links:
                self.property_links_model.remove_silent(model, prop_name)
        self._slave_property_links = []
        self._window_nodes = {}
        self._other_scene = None
        self._other_view = None

    def expand_into_scene(self, scene, composite_node, original_nodes=None):
        """
        Expand this node into *scene* and replace *composite_node*'s connections with
        connections going straight into its subnodes. Also create connections internal to this
        node and update property links. *original_nodes* is a dictionary in form {caption:
        node_state} which will be used for positioning of the replacing nodes
        (scene.restore_node instead of scene.create_node will be called).
        """
        assert self.property_links_model is not None
        # Connections to external nodes
        connections = []
        # name_to_node is in format caption: new node dictionary
        # Internal connections are handled in _expand_into_scene
        name_to_node = self._expand_into_scene(scene, original_nodes=original_nodes,
                                               restore_captions=False)

        for port_type in [PortType.input, PortType.output]:
            for index, port in composite_node[port_type].items():
                if port.connections:
                    connection = port.connections[0]
                    outside_port = connection.valid_ports[opposite_port(port_type)]
                    internal_model, pt, pi = self._inside_ports[(port_type, index)]
                    connections.append((outside_port,
                                        name_to_node[internal_model.caption][pt][pi]))

        # Update property links
        for (subcaption, subnode) in name_to_node.items():
            if isinstance(subnode.model, BaseCompositeModel):
                # Get all leaf PropertyModel instances
                paths = subnode.model.get_leaf_paths()
            else:
                paths = [[subnode.model]]
            # In case selected node is composite, replace all leaf node links
            for path in paths:
                str_path = [model.caption for model in path]
                # Captions might have changed if subnode captions were equal to other captions
                # in the scene and the composite node which is being replaced contains still the
                # old ones
                old_str_path = [subcaption] + str_path[1:]
                old_model = composite_node.model.get_model_from_path(old_str_path)
                self.property_links_model.replace_item(subnode, path[-1], old_model)
            subnode.graphics_object.setSelected(True)

        scene.remove_node(composite_node)
        # Create outside connections only after the composite node has been deleted to prevent
        # creating multiple connections per input port in the outside nodes
        for outside, inside in connections:
            scene.create_connection(outside, inside, check_cycles=False)

        return name_to_node, connections
def get_composite_model_class(composite_name, models, connections, links=None):
    """
    Return a new BaseCompositeModel subclass called *composite_name* built from *models*,
    *connections* and optionally *links*.

    Raises UfoModelError if *composite_name* is empty.
    """
    if not composite_name:
        raise UfoModelError('composite name must be specified')

    class CompositeModel(BaseCompositeModel):
        name = composite_name
        data_type = UFO_DATA_TYPE

        def __init__(self, style=None, parent=None, registry=None):
            super().__init__(models, connections, links=links, registry=registry,
                             style=style, parent=parent)

    # Set after class creation, like the original construction did
    composite_cls = CompositeModel
    composite_cls.caption_visible = True
    composite_cls.caption = composite_name

    return composite_cls
class UfoGeneralBackprojectModel(UfoTaskModel):
    """Model of the general-backproject task which can split its work across several GPUs."""

    name = 'general_backproject'
    num_ports = {PortType.input: 1, PortType.output: 1}
    data_type = UFO_DATA_TYPE

    def __init__(self, style=None, parent=None, scrollable=True):
        super().__init__('general-backproject', style=style, parent=parent, scrollable=scrollable)
        self.needs_fixed_scheduler = True
        self.can_split_gpu_work = True

    def make_properties(self):
        """Add the 'slice-memory-coeff' property (hidden by default) to the task properties."""
        properties = super().make_properties()
        slice_memory_coeff = NumberQLineEditViewItem(0.01, 1., default_value=0.8,
                                                     tooltip='Portion of used GPU memory')
        properties['slice-memory-coeff'] = [slice_memory_coeff, False]
        return properties

    def split_gpu_work(self, gpus):
        """
        Validate the regions and return the runs computed by tofu.genreco.make_runs for *gpus*.

        Raises UfoModelError if 'region', 'x-region' or 'y-region' is not a valid argument list
        for a non-empty np.arange.
        """
        from tofu.genreco import make_runs, DTYPE_CL_SIZE

        def check_region(region):
            try:
                num_items = len(np.arange(*self[region]))
            except (TypeError, ValueError, ZeroDivisionError) as error:
                # E.g. a wrong number of items or a zero step; report it like an empty range
                raise UfoModelError(f'Invalid {region} {self[region]}') from error
            if not num_items:
                raise UfoModelError(f'Invalid {region} {self[region]}')

        # Check if ranges are OK
        check_region('region')
        check_region('x-region')
        check_region('y-region')
        gpu_indices = range(len(gpus))
        bpp = DTYPE_CL_SIZE[self['store-type']]
        runs = make_runs(gpus, gpu_indices, self['x-region'], self['y-region'],
                         self['region'], bpp, slice_memory_coeff=self['slice-memory-coeff'])

        return runs

    def _setup_ufo_task(self, ufo_task, region=None):
        """Copy the task properties to *ufo_task*; *region* overrides the 'region' property."""
        # 'slice-memory-coeff' is not a task property and 'region' is handled below
        separate = ['region', 'slice-memory-coeff']
        task_props = [prop for prop in self if prop not in separate]
        for prop in task_props:
            setattr(ufo_task.props, prop, self[prop])
        # Set region separately in case there are multiple inputs
        current_region = self['region'] if region is None else region
        setattr(ufo_task.props, 'region', current_region)
class UfoVaryingInputModel(UfoTaskModel):
    """
    Base class for models which can have a varying number of inputs.

    If *num_inputs* is not given, the user is asked for it in a dialog (1 to 10 inputs).
    UfoModelError is raised if the dialog is cancelled. There is always exactly one output.
    """

    def __init__(self, task_name, style=None, parent=None, scrollable=True, num_inputs=None,
                 dialog_title='Number of inputs', dialog_label='Number of inputs:'):
        if not num_inputs:
            num_inputs, ok = QInputDialog.getInt(parent, dialog_title, dialog_label,
                                                 value=1, minValue=1, maxValue=10, step=1)
            if not ok:
                raise UfoModelError('Number of inputs must be specified')

        input_indices = range(num_inputs)
        self.num_ports = {PortType.input: num_inputs, PortType.output: 1}
        self.data_type = {PortType.output: {0: UFO_DATA_TYPE},
                          PortType.input: {i: UFO_DATA_TYPE for i in input_indices}}
        self.port_caption = {PortType.output: {0: ''},
                             PortType.input: {i: '' for i in input_indices}}
        self.port_caption_visible = {PortType.output: {0: False},
                                     PortType.input: {i: False for i in input_indices}}
        super().__init__(task_name, style=style, parent=parent, scrollable=scrollable)

    def save(self):
        """Save the state including the number of inputs, needed to re-create the model."""
        state = super().save()
        state['num-inputs'] = self.num_ports['input']
        return state
class UfoOpenCLModel(UfoVaryingInputModel):
    """Model of the opencl task with a user-defined number of inputs."""

    name = 'opencl'

    def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None):
        super().__init__('opencl', style=style, parent=parent, scrollable=scrollable,
                         num_inputs=num_inputs)

    def _setup_ufo_task(self, ufo_task, region=None):
        """Copy properties to *ufo_task*, turning empty 'filename'/'source' into None."""
        nullable = ('filename', 'source')
        for prop in self:
            value = self[prop]
            if prop in nullable and not value:
                # opencl task really needs NULL
                value = None
            setattr(ufo_task.props, prop, value)
class UfoReadModel(UfoTaskModel):
    """Model of the read task; 'path' can be a single file or a directory."""

    name = 'read'
    num_ports = {PortType.input: 0, PortType.output: 1}
    data_type = UFO_DATA_TYPE

    def __init__(self, style=None, parent=None, scrollable=True):
        super().__init__('read', style=style, parent=parent, scrollable=scrollable)

    def auto_fill(self):
        """
        Set 'number' to the total number of images found in 'path'.

        Unreadable files are logged and skipped. UfoModelError is raised if no image is found.
        """
        import glob
        import imageio

        if os.path.isdir(self['path']):
            paths = sorted(glob.glob(os.path.join(self['path'], '*')))
        else:
            paths = [self['path']]
        num_images = 0
        for path in paths:
            try:
                # Close the reader deterministically instead of relying on garbage collection
                with imageio.get_reader(path) as reader:
                    num_images += len(reader)
            except Exception:
                LOG.error(f"Error reading '{path}'")
        if not num_images:
            raise UfoModelError(f"No images found in {self['path']}")
        self['number'] = num_images

    def double_clicked(self, parent):
        """Let the user choose 'path', starting in the current directory (or home)."""
        current_path = self['path']
        if not os.path.isdir(current_path):
            current_path = os.path.dirname(current_path)
        if not current_path:
            current_path = QtCore.QDir.homePath()
        dialog = FileDirDialog()
        # Start browsing where the current path points to
        dialog.setDirectory(current_path)
        if dialog.exec_():
            self['path'] = dialog.selectedFiles()[0]

    def _setup_ufo_task(self, ufo_task, region=None):
        """Copy properties to *ufo_task*; 'raw-bitdepth' is only passed on when it is set."""
        for prop in self:
            if prop != 'raw-bitdepth' or self['raw-bitdepth']:
                setattr(ufo_task.props, prop, self[prop])
class UfoRetrievePhaseModel(UfoVaryingInputModel):
    """Model of the retrieve-phase task; every input corresponds to one propagation distance."""

    name = 'retrieve_phase'

    def __init__(self, style=None, parent=None, scrollable=True, num_inputs=None):
        super().__init__('retrieve-phase', style=style, parent=parent, scrollable=scrollable,
                         dialog_title='Multi-distance Setup', dialog_label='Number of distances:',
                         num_inputs=num_inputs)

    def make_properties(self):
        """
        Replace 'distance' by a range item with one value per input.

        With more than one input, 'method' is forced to multi-distance CTF and the 'method',
        'distance-x' and 'distance-y' widgets are disabled.
        """
        properties = super().make_properties()
        num_distances = self.num_ports['input']
        # Override distance property based on how many inputs we expect
        distance_item = RangeQLineEditViewItem(tooltip=properties['distance'][0].widget.toolTip(),
                                               default_value=[], num_items=num_distances,
                                               is_float=True)
        properties['distance'] = [distance_item, True]
        if num_distances > 1:
            properties['method'][0].set('ctf_multidistance')
            for key in ('method', 'distance-x', 'distance-y'):
                properties[key][0].widget.setEnabled(False)

        return properties
class UfoWriteModel(UfoTaskModel):
    """
    Model of the write task.

    A ``{region}`` placeholder in 'filename' enables multiple inputs. Each region is then
    written to a file whose placeholder is replaced by the region's first value.
    """

    name = 'write'
    num_ports = {PortType.input: 1, PortType.output: 0}
    data_type = UFO_DATA_TYPE

    def __init__(self, style=None, parent=None, scrollable=True):
        super().__init__('write', style=style, parent=parent, scrollable=scrollable)

    def double_clicked(self, parent):
        """Let the user choose the output file name, starting in its directory (or home)."""
        current_path = os.path.dirname(self['filename'])
        if not current_path:
            current_path = QtCore.QDir.homePath()
        file_name, _ = QFileDialog.getSaveFileName(None, "Select File Name", current_path)
        if file_name:
            self['filename'] = file_name

    @property
    def expects_multiple_inputs(self):
        return '{region}' in self['filename']

    def _setup_ufo_task(self, ufo_task, region=None):
        if region is not None and not self.expects_multiple_inputs:
            raise UfoModelError('Write got region without enabling multiple inputs. '
                                'Add {region} somewhere in the "filename" field to enable it.')
        super()._setup_ufo_task(ufo_task, region=region)
        filename = self['filename']
        if region is not None:
            # Substitute only the placeholder, str.format would fail on any other braces which
            # may legitimately appear in a path
            filename = filename.replace('{region}', str(region[0]))
        setattr(ufo_task.props, 'filename', filename)
class _Batch(QObject):
    """
    Pre-allocated float32 array which *ufo_task* writes into via its 'pointer' property.

    ``finished(batch_id)`` is emitted once the task has reported as many 'processed' events as
    the array has rows (entries along the first axis).
    """

    finished = pyqtSignal(int)

    def __init__(self, ufo_task, shape, batch_id):
        super().__init__(parent=None)
        self.batch_id = batch_id
        self.data = np.empty(shape, dtype=np.float32)
        # Hand the raw buffer address and its size over to the task
        ufo_task.props.pointer = self.data.__array_interface__['data'][0]
        ufo_task.props.max_size = self.data.nbytes
        ufo_task.connect('processed', self._on_processed)
        self.num_processed = 0

    def _on_processed(self, ufo_task):
        self.num_processed += 1
        if self.num_processed != self.data.shape[0]:
            return
        self.finished.emit(self.batch_id)
class UfoMemoryOutModel(UfoTaskModel):
    """Model wrapping the UFO ``memory-out`` task, which writes into numpy arrays.

    Every :meth:`_setup_ufo_task` call allocates a new :class:`_Batch`. Finished batches are
    published on the output port strictly in the order in which they were created, even if
    they finish out of order. Typing ``{region}`` into the "number" field enables multiple
    inputs (one batch per region).
    """

    name = 'memory_out'
    num_ports = {PortType.input: 1, PortType.output: 1}
    data_type = {PortType.input: {0: UFO_DATA_TYPE},
                 PortType.output: {0: ARRAY_DATA_TYPE}}
    port_caption = {PortType.input: {0: ''},
                    PortType.output: {0: ''}}
    port_caption_visible = {PortType.input: {0: False},
                            PortType.output: {0: False}}

    def __init__(self, style=None, parent=None, scrollable=True):
        # Batch bookkeeping is set up before the base class initializer runs
        self._lock = Lock()
        self.reset_batches()
        super().__init__('memory-out', style=style, parent=parent, scrollable=scrollable)

    @property
    def expects_multiple_inputs(self):
        """True if the "number" property is the ``{region}`` placeholder."""
        return self['number'] == '{region}'

    def make_properties(self):
        """Return the properties describing the output buffer: input width, height, depth and
        the number of inputs (or ``{region}``)."""
        width_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input width')
        height_item = IntQLineEditViewItem(0, 1000000, default_value=0, tooltip='Input height')
        depth_item = IntQLineEditViewItem(0, 1000000, default_value=1,
                                          tooltip='Input depth (for 2D images should be 1)')
        number_item = QLineEditViewItem(default_value=1, tooltip='Number of inputs')
        properties = {'width': [width_item, True],
                      'height': [height_item, True],
                      'depth': [depth_item, True],
                      'number': [number_item, True]}

        return properties

    def consume_batch(self, batch_id):
        """Publish the finished batch *batch_id* if it is the next expected one.

        A batch which finishes early is put on a waiting list; once the expected batch arrives,
        it and all directly following waiting batches are published in order.
        """
        def consume(current_batch):
            LOG.debug(f'{self.caption}: consuming {current_batch.batch_id} (caller {batch_id})')
            self._current_data = current_batch.data
            self.data_updated.emit(0)
            # Free memory up
            self._batches[self._expecting_id] = None

        with self._lock:
            if self._expecting_id != batch_id:
                LOG.debug(f'{self.caption}: putting {batch_id} on waiting list')
                self._waiting_list.append(batch_id)
                return
            consume(self._batches[self._expecting_id])
            self._expecting_id += 1
            while self._expecting_id in self._waiting_list:
                consume(self._batches[self._expecting_id])
                self._waiting_list.remove(self._expecting_id)
                self._expecting_id += 1

    def out_data(self, port: int) -> NodeData:
        """Return the array of the most recently published batch (None before the first)."""
        LOG.debug(f'{self.caption}: out_data shape:'
                  f'{None if self._current_data is None else self._current_data.shape}')
        return self._current_data

    def reset_batches(self):
        """Forget all batches, the waiting list and the published data."""
        self._batches = []
        self._waiting_list = []
        self._expecting_id = 0
        self._current_data = None

    def _setup_ufo_task(self, ufo_task, region=None):
        """Allocate a new batch buffer for *ufo_task*.

        The buffer has shape (number, height, width), or (number, depth, height, width) when
        the "depth" property is greater than one. *number* is the "number" property, or the
        number of indices produced by ``np.arange(*region)`` if *region* is given.

        :raises UfoModelError: if *region* and the "number" property are inconsistent
        """
        if region is not None and not self.expects_multiple_inputs:
            raise UfoModelError('Memory Out got region without enabling multiple inputs. '
                                'Type {region} in the "number" field to enable it.')
        if region is None:
            if self.expects_multiple_inputs:
                # Otherwise int('{region}') would fail with an unhelpful ValueError
                raise UfoModelError('Memory Out expects multiple inputs ("number" is {region}) '
                                    'but got no region.')
            number = int(self['number'])
        else:
            number = len(np.arange(*region))
        depth = self['depth']
        if depth > 1:
            # Every input item carries *depth* slices, the buffer must hold all of them,
            # otherwise the task's max_size is exceeded
            shape = (number, depth, self['height'], self['width'])
        else:
            shape = (number, self['height'], self['width'])
        with self._lock:
            batch = _Batch(ufo_task, shape, len(self._batches))
            self._batches.append(batch)
            batch.finished.connect(self.consume_batch)
class ImageViewerModel(UfoModel):
    """Node showing incoming arrays in an embedded ``ImageViewer``.

    The first data after a reset replaces the viewer's images, later data is appended.
    """

    name = 'image_viewer'
    caption = 'Image Viewer'
    num_ports = {PortType.input: 1, PortType.output: 0}
    data_type = ARRAY_DATA_TYPE

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._node_data = None
        from tofu.flow.viewer import ImageViewer
        self._widget = ImageViewer()
        # True means the next incoming data replaces instead of extends the images
        self._reset = True

    def embedded_widget(self):
        return self._widget

    def resizable(self):
        return True

    def double_clicked(self, parent):
        """Pop up the viewer if it has images, is not popped up yet and pyqtgraph exists."""
        viewer = self._widget
        try:
            if viewer.images is not None and not viewer.popup_visible:
                import pyqtgraph  # noqa: F401 (availability check only)
                viewer.popup()
        except ImportError:
            LOG.debug('pyqtgraph not installed, not popping up')

    def set_in_data(self, data: NodeData, port: Port):
        if data is None:
            return
        if not self._reset:
            self._widget.append(data)
            return
        self._widget.images = data
        self._reset = False

    def reset_batches(self):
        self._reset = True

    def cleanup(self):
        self._widget.cleanup()
def get_ufo_model_classes(names=None):
    """Return a generator of model classes for the UFO tasks called *names*.

    When *names* is empty or None, all installed tasks except a fixed blacklist are used
    (``stamp``, for instance, causes a gobject unref warning).
    """
    installed = set(UFO_PLUGIN_MANAGER.get_all_task_names())
    excluded = {'general-backproject', 'memory-in', 'memory-out', 'opencl', 'read',
                'retrieve-phase', 'stamp', 'write'}
    allowed = list(installed - excluded)
    return (get_ufo_model_class(task_name) for task_name in names or allowed)
def get_composite_model_classes_from_json(state):
    """
    Get composite model classes from their json representation. This is recursive in case a user
    creates a composite inside the scene, then adds nodes and creates another composite with the
    first one inside and doesn't export explicitly the first one. The order of returned classes is
    bottom -> up, i.e. first the classes which have strictly non-composite submodels are returned
    and the top level class is last.
    """
    classes = []

    def collect(composite):
        connections = [CompositeConnection(*args) for args in composite['connections']]
        submodels = []
        for entry in composite['models'].values():
            submodel_state = entry['model']
            if 'models' in submodel_state and 'connections' in submodel_state:
                # Nested composite, its class must be created before the enclosing one
                collect(submodel_state)
            # models are tuples (name, state, visible, position)
            submodels.append((entry['name'], submodel_state, entry['visible'],
                              entry['position']))
        classes.append(get_composite_model_class(composite['name'], submodels, connections,
                                                 links=composite.get('links', None)))

    collect(state)
    return classes
def get_composite_model_classes():
    """Load composite model classes from ``*.cm`` files.

    The package ``composites`` directory and ``$XDG_DATA_HOME/tofu/flows/composites`` are
    searched, files sorted by name. Returns one list (as returned by
    :func:`get_composite_model_classes_from_json`) per successfully loaded file; files which
    fail to load are logged and skipped.
    """
    from xdg import xdg_data_home
    search_dirs = (pkg_resources.resource_filename(__name__, 'composites'),
                   os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites'))
    loaded = []
    for directory in search_dirs:
        for file_name in sorted(glob.glob(os.path.join(directory, '*.cm'))):
            LOG.debug(f'Loading composite from {file_name}')
            try:
                with open(file_name, 'r') as f:
                    state = json.load(f)
                loaded.append(get_composite_model_classes_from_json(state))
            except Exception as e:
                # Best effort: one broken file must not prevent loading the others
                LOG.error(e, exc_info=True)
    return loaded
class UfoModelError(FlowError):
    """Raised when a UFO model is set up inconsistently, e.g. got a region without enabling
    multiple inputs."""
tofu-0.12.0/tofu/flow/propertylinksmodels.py 0000664 0000000 0000000 00000037113 14237137211 0021204 0 ustar 00root root 0000000 0000000 import logging
from PyQt5.QtCore import QDataStream, pyqtSignal
from PyQt5.QtGui import QStandardItemModel, QStandardItem
from tofu.flow.models import PropertyModel, BaseCompositeModel
from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE
LOG = logging.getLogger(__name__)
def _decode_mime_data(data):
    """Read a model index location from *data* (:class:`QMimeData`).

    Returns (row, column, internal id), read from the ``application/x-sourcetreemodelindex``
    payload as two 32-bit integers and one unsigned 64-bit integer.
    """
    payload = data.data('application/x-sourcetreemodelindex')
    reader = QDataStream(payload)
    row = reader.readInt32()
    column = reader.readInt32()
    return (row, column, reader.readUInt64())
def _data_from_tree_index(index):
    """
    Return (node, model, property name) for *index*, which must be a property record (leaf in
    the tree). The model is stored in the leaf's parent, the node in the closest item at or above
    that parent which carries one.
    """
    prop_name = index.data()
    model_index = index.parent()
    model = model_index.data(role=MODEL_ROLE)
    node_index = model_index
    while node_index.isValid() and node_index.data(role=NODE_ROLE) is None:
        node_index = node_index.parent()
    return (node_index.data(role=NODE_ROLE), model, prop_name)
def _get_string_path(node, model, prop_name):
    """Return the captions of the models leading from *node* to *model* followed by
    *prop_name*. For non-composite nodes this is just [model caption, prop_name]."""
    if isinstance(node.model, BaseCompositeModel):
        chain = node.model.get_path_from_model(model)
    else:
        chain = [model]
    return [link.caption for link in chain] + [prop_name]
class NodeTreeModel(QStandardItemModel):
    """Tree model representing nodes in the scene.

    Top-level items store the node (NODE_ROLE) and its model (MODEL_ROLE). Composite models
    get one child per submodel (recursively), property models one non-editable child per
    property name.
    """

    def add_node(self, node):
        """Add *node* as a top-level item if its model is a property or composite model."""
        item = self._add_model(node.model)
        if item:
            item.setData(node, role=NODE_ROLE)

    def remove_node(self, node):
        """Remove the top-level row of *node*, if present."""
        for j in range(self.rowCount()):
            item = self.item(j, 0)
            if item and item.data(role=NODE_ROLE) == node:
                self.removeRow(j)
                break

    def clear(self):
        """Remove all rows and columns.

        In PyQt5, clear doesn't emit the rowsAboutToBeRemoved signal, which listeners of this
        model rely on, so rows and columns are removed explicitly instead.
        """
        self.removeRows(0, self.rowCount())
        self.removeColumns(0, self.columnCount())

    def set_nodes(self, nodes):
        """Replace the tree content with *nodes*."""
        self.clear()
        for node in nodes:
            self.add_node(node)

    def _add_model(self, flow_model, parent=None):
        """Append an item for *flow_model* under *parent* (root item by default).

        Children are sorted by name. Returns the new item, or None if *flow_model* is neither
        a property nor a composite model (in which case nothing is added).
        """
        if not parent:
            parent = self.invisibleRootItem()
        item = None
        if isinstance(flow_model, (PropertyModel, BaseCompositeModel)):
            item = QStandardItem(flow_model.caption)
            item.setData(flow_model, role=MODEL_ROLE)
            item.setEditable(False)
            if isinstance(flow_model, PropertyModel):
                for prop in sorted(flow_model):
                    prop_item = QStandardItem(prop)
                    prop_item.setEditable(False)
                    item.appendRow(prop_item)
            else:
                for submodel_name in sorted(flow_model):
                    self._add_model(flow_model[submodel_name], parent=item)
        if item:
            parent.appendRow(item)

        return item
class PropertyLinksModel(QStandardItemModel):
    """Links model representing property links between nodes in the scene.

    Every row is one group of linked properties. A cell holds an item whose text is the
    '->'-joined caption path to a property and which stores the leaf model (MODEL_ROLE), the
    property name (PROPERTY_ROLE) and the root node (NODE_ROLE). When a property of a row
    changes, the value is written to all other properties of the row.

    Additionally, "silent" slaves can be registered: (model, property name) pairs which are not
    shown in the table but follow a (root model, root property name) pair which is, e.g. models
    of a composite being edited in a subwindow.
    """

    restored = pyqtSignal()

    def __init__(self, node_model):
        super().__init__()
        # (slave model, property name) -> (root model, root property name)
        self._silent = {}
        # (root model, root property name) -> [(slave model, property name), ...]
        self._slaves = {}
        self._node_model = node_model
        self._node_model.rowsAboutToBeRemoved.connect(self.on_node_rows_about_to_be_removed)

    def __contains__(self, key):
        """Return True if an item with text *key* exists in any column."""
        for column in range(self.columnCount()):
            if self.findItems(key, column=column):
                return True
        return False

    def clear(self):
        """Remove all items (disconnecting their models and slaves) and clear the model."""
        for j in range(self.rowCount()):
            for i in range(self.columnCount()):
                item = self.item(j, i)
                if item is not None:
                    self.remove_item(self.indexFromItem(item))
        super().clear()

    def find_items(self, data_list, roles):
        """Return items whose data for every role in *roles* equals the corresponding entry of
        *data_list*."""
        result = []
        for j in range(self.rowCount()):
            for i in range(self.columnCount()):
                item = self.item(j, i)
                if item and not any(item.data(role=role) != data
                                    for (data, role) in zip(data_list, roles)):
                    result.append(item)
        return result

    def get_model_links(self, models):
        """
        Get links between *models*. Return dict {row index: [str_path, ...]}, where *str_path* is
        the path from the topmost model (in case of composites along the way) to the property name.
        """
        items = {}
        for model in models:
            for item in self.find_items([model], [MODEL_ROLE]):
                items.setdefault(item.row(), []).append(item.text().split('->'))
        return items

    def get_root_model(self, model):
        """Return *model* if it is linked in the table, otherwise the root model it silently
        follows, or None."""
        items = self.find_items([model], [MODEL_ROLE])
        if items:
            return items[0].data(role=MODEL_ROLE)
        root_model = None
        for (silent_model, _), root_key in self._silent.items():
            if silent_model == model:
                root_model = root_key[0]
        return root_model

    def get_model_properties(self, model):
        """Return the names of all linked properties of *model*."""
        items = self.find_items([model], [MODEL_ROLE])
        return [item.data(role=PROPERTY_ROLE) for item in items]

    def add_item(self, node, model, prop_name, row, column, insert=False):
        """
        Add item where *node* is the root node (can be composite), *model* is the leaf model
        (there can be composites above if the leaf is nested) and *prop_name* is the property name.
        *row* and *column* determine the table cell to which to add the item or replace an old item
        with the new one. If *insert* is True, insert a new row at *row*. *row* == -1 means a new
        last row, *column* == -1 means the first empty cell of the row.

        :raises ValueError: if the property path is already linked
        """
        str_path = '->'.join(_get_string_path(node, model, prop_name))
        if str_path in self:
            raise ValueError(f'{str_path} already inside')
        item = QStandardItem(str_path)
        item.setData(model, role=MODEL_ROLE)
        item.setData(prop_name, role=PROPERTY_ROLE)
        item.setData(node, role=NODE_ROLE)
        item.setEditable(False)
        if row == -1:
            row = self.rowCount()
        if column == -1:
            # +1 to find an empty cell even if the row is full
            for i in range(self.columnCount() + 1):
                if self.item(row, i) is None:
                    column = i
                    break
        LOG.debug(f'Add item {node.model.caption}({item.data(role=MODEL_ROLE)}):'
                  f'{item.data(role=PROPERTY_ROLE)} at ({row}, {column})')
        if insert:
            self.insertRow(row, item)
        else:
            self.setItem(row, column, item)
        # In case the composite is being edited in a subwindow, connect the slave nodes from the
        # subscene
        if isinstance(node.model, BaseCompositeModel):
            node.model.add_slave_links()
        model.property_changed.connect(self.on_property_changed)

    def remove_item(self, index):
        """Remove the item at *index* together with all silent slaves of its property. Empty
        cells are ignored."""
        flow_model = index.data(role=MODEL_ROLE)
        if not flow_model:
            # Empty cell
            return
        property_name = index.data(role=PROPERTY_ROLE)
        flow_model.property_changed.disconnect(self.on_property_changed)
        self.setItem(index.row(), index.column(), None)
        # Remove all associated slaves
        root_key = (flow_model, property_name)
        if root_key in self._slaves:
            for slave_key in tuple(self._slaves[root_key]):
                self.remove_silent(*slave_key)

    def add_silent(self, model, property_name, root, root_property_name):
        """Make *model*'s *property_name* a silent slave of *root*'s *root_property_name*.

        Does nothing if the slave is already registered.

        :raises ValueError: if the root property is not in the table
        """
        key = (model, property_name)
        if key in self._silent:
            return
        root_key = (root, root_property_name)
        # Validate before connecting, otherwise a failed call would leave a dangling connection
        # which remove_silent could never disconnect
        if not self.find_items(root_key, (MODEL_ROLE, PROPERTY_ROLE)):
            raise ValueError(f'{root}->{root_property_name} not in property links')
        model.property_changed.connect(self.on_property_changed)
        self._silent[key] = root_key
        self._slaves.setdefault(root_key, []).append(key)
        LOG.debug(f'Slave {root}->{root_property_name} -> {model}->{property_name} added')

    def remove_silent(self, model, property_name):
        """Unregister the silent slave *model*'s *property_name* (no-op if not registered)."""
        key = (model, property_name)
        if key not in self._silent:
            # Already removed, e.g. by deleting an item by del key while some composite windows were
            # still opened
            return
        model.property_changed.disconnect(self.on_property_changed)
        root_key = self._silent.pop(key)
        self._slaves[root_key].remove(key)
        if not self._slaves[root_key]:
            del self._slaves[root_key]
        LOG.debug(f'Slave {model}->{property_name} removed')

    def replace_item(self, node, new_model, old_model):
        """Re-link every property of *old_model* to *new_model* in place, keeping slaves."""
        for j in range(self.rowCount()):
            for i in range(self.columnCount()):
                item = self.item(j, i)
                if item and item.data(role=MODEL_ROLE) == old_model:
                    # Don't break, replace all properties of *old_model*
                    prop_name = item.data(role=PROPERTY_ROLE)
                    slaves = tuple(self._slaves.get((old_model, prop_name), []))
                    self.remove_item(self.indexFromItem(item))
                    self.add_item(node, new_model, prop_name, j, i)
                    for (slave_model, slave_property_name) in slaves:
                        self.add_silent(slave_model, slave_property_name, new_model, prop_name)

    def on_node_rows_about_to_be_removed(self, parent, first, last):
        """Drop links of nodes which are about to be removed from the node tree model."""
        for k in range(first, last + 1):
            node = self._node_model.item(k, 0).data(role=NODE_ROLE)
            for j in range(self.rowCount()):
                for i in range(self.columnCount()):
                    item = self.item(j, i)
                    if item and item.data(role=NODE_ROLE) == node:
                        self.remove_item(self.indexFromItem(item))
            self.compact()

    def canDropMimeData(self, data, action, row, column, parent):
        """Accept a dragged property if it is not linked yet and, when dropped onto a row, its
        value type matches the type of the row's properties."""
        can_drop = False
        if data.hasFormat('application/x-sourcetreemodelindex'):
            src_row, src_column, src_internal_id = _decode_mime_data(data)
            src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id)
            # src_model_index is the property, its parent is the model
            node, flow_model, property_name = _data_from_tree_index(src_model_index)
            str_path = '->'.join(_get_string_path(node, flow_model, property_name))
            can_drop = str_path not in self
            if parent.isValid():
                # Parent itself can be an empty cell, so use the first column which is for sure
                # occupied since the parent is valid (row exists and we are not between rows)
                first_item = self.item(parent.row(), 0)
                parent_model = first_item.data(role=MODEL_ROLE)
                parent_property_name = first_item.data(role=PROPERTY_ROLE)
                if type(flow_model[property_name]) is not type(parent_model[parent_property_name]):
                    # Data can be dropped only if the types of properties match
                    can_drop = False
        return can_drop

    def dropMimeData(self, data, action, row, column, parent):
        """Add the dragged property to the row under the cursor or to a new row."""
        src_row, src_column, src_internal_id = _decode_mime_data(data)
        src_model_index = self._node_model.createIndex(src_row, src_column, src_internal_id)
        node, flow_model, property_name = _data_from_tree_index(src_model_index)
        if parent.isValid():
            row = parent.row()
            insert = False
        else:
            insert = True
        # drops never replace items and column=-1 means "find an empty cell"
        self.add_item(node, flow_model, property_name, row, -1, insert=insert)
        return True

    def save(self):
        """Return the links as a list of rows, each a list of [node id, string path]."""
        state = []
        for j in range(self.rowCount()):
            row_state = []
            for i in range(self.columnCount()):
                item = self.item(j, i)
                if not item:
                    continue
                node = item.data(role=NODE_ROLE)
                model = item.data(role=MODEL_ROLE)
                prop_name = item.data(role=PROPERTY_ROLE)
                str_path = _get_string_path(node, model, prop_name)
                row_state.append([node.id, str_path])
            state.append(row_state)
        return state

    def restore(self, state, nodes):
        """Replace the links with *state* (as returned by :meth:`save`); *nodes* maps node ids
        to nodes. Emits :attr:`restored` afterwards."""
        self.clear()
        for (j, row) in enumerate(state):
            for (i, (node_id, path)) in enumerate(row):
                node = nodes[node_id]
                # Last path entry is the property name
                if isinstance(node.model, BaseCompositeModel):
                    flow_model = node.model.get_model_from_path(path[1:-1])
                else:
                    flow_model = node.model
                self.add_item(node, flow_model, path[-1], j, i)
        self.restored.emit()

    def compact(self):
        """Shift items to the left and remove empty rows and columns."""
        # Shift rows to the left
        for j in range(self.rowCount()):
            filled = [self.takeItem(j, i) for i in range(self.columnCount()) if self.item(j, i)]
            for (i, item) in enumerate(filled):
                self.setItem(j, i, item)
        # Check empty rows
        for j in reversed(range(self.rowCount())):
            if not any(self.item(j, i) for i in range(self.columnCount())):
                self.removeRow(j)
        # Check empty columns
        for i in reversed(range(self.columnCount())):
            if not any(self.item(j, i) for j in range(self.rowCount())):
                self.removeColumn(i)

    def _find_row(self, model, property_name):
        """Return the row containing *model*'s *property_name*, or -1."""
        for j in range(self.rowCount()):
            for i in range(self.columnCount()):
                item = self.item(j, i)
                if (item and item.data(role=MODEL_ROLE) == model
                        and item.data(role=PROPERTY_ROLE) == property_name):
                    return j
        return -1

    def on_property_changed(self, sig_model, sig_property_name, value):
        """Propagate *value* from the changed property to its root, all properties linked in
        the same row and all their silent slaves (except the signal source)."""
        LOG.debug(f'on_property_changed: {sig_model}, {sig_model.caption}, '
                  f'{sig_property_name}, {value}')
        sig_key = (sig_model, sig_property_name)
        if sig_key in self._silent:
            # pyqtSignal came from a composite subwindow, get root model from the silent slave
            root_key = self._silent[sig_key]
            root_key[0][root_key[1]] = value
        else:
            root_key = sig_key
        row = self._find_row(*root_key)
        for i in range(self.columnCount()):
            # For row == -1 (not linked) item() returns None and nothing happens
            item = self.item(row, i)
            if not item:
                continue
            model = item.data(role=MODEL_ROLE)
            property_name = item.data(role=PROPERTY_ROLE)
            if root_key != (model, property_name):
                model[property_name] = value
            # Notify all slaves
            key = (model, property_name)
            if key in self._slaves:
                for (slave_model, slave_property_name) in self._slaves[key]:
                    if (slave_model, slave_property_name) != sig_key:
                        slave_model[slave_property_name] = value
tofu-0.12.0/tofu/flow/propertylinkswidget.py 0000664 0000000 0000000 00000007023 14237137211 0021201 0 ustar 00root root 0000000 0000000 from PyQt5.QtCore import QMimeData, Qt, QDataStream, QByteArray, QIODevice, QModelIndex
from PyQt5.QtGui import QDrag
from PyQt5.QtWidgets import QAbstractItemView, QLabel, QTableView, QTreeView, QVBoxLayout, QWidget
def _encode_mime_data(index: QModelIndex):
    """Pack the location of *index* into :class:`QMimeData`.

    Row and column are written as 32-bit integers, followed by the internal id as an unsigned
    64-bit integer, under the ``application/x-sourcetreemodelindex`` format.
    """
    mime_data = QMimeData()
    payload = QByteArray()
    writer = QDataStream(payload, QIODevice.WriteOnly)
    try:
        writer.writeInt32(index.row())
        writer.writeInt32(index.column())
        writer.writeUInt64(index.internalId())
    finally:
        # Close the stream's device even if writing failed
        writer.device().close()
    mime_data.setData("application/x-sourcetreemodelindex", payload)
    return mime_data
class PropertyLinksView(QTableView):
    """Table view for displaying node property links.

    Delete removes the selected links; every other key is handled by :class:`QTableView`
    (e.g. keyboard navigation).
    """

    def keyPressEvent(self, event):
        if event.key() != Qt.Key_Delete:
            # Unhandled keys must reach the base class, otherwise navigation breaks
            super().keyPressEvent(event)
            return
        model = self.model()
        for index in self.selectedIndexes():
            model.remove_item(index)
        model.compact()
class NodesView(QTreeView):
    """Tree view displaying nodes in the scene; only leaves (properties) can be dragged."""

    def get_drag_index(self):
        """Return the first selected index if it is a leaf, otherwise None."""
        selected = self.selectedIndexes()
        if not selected:
            return None
        index = selected[0]
        # Same test as the deprecated index.child(0, 0): does a first child exist?
        if self.model().index(0, 0, index).isValid():
            return None
        return index

    def mouseMoveEvent(self, event):
        """All that a mouse *event* can do is start a drag and drop operation."""
        index = self.get_drag_index()
        if index is None:
            return
        drag = QDrag(self)
        drag.setMimeData(_encode_mime_data(index))
        drag.exec_(Qt.CopyAction)
class PropertyLinks(QWidget):
    """Widget displaying nodes in the scene and their property links in one window.

    Properties are dragged from the node tree (top) into the links table (bottom).
    """

    def __init__(self, node_model, table_model, parent=None):
        super().__init__(parent=parent, flags=Qt.Window)
        self.setWindowTitle('Property Links')
        self.resize(600, 800)

        self._treeview = tree = NodesView()
        tree.setHeaderHidden(True)
        tree.setAlternatingRowColors(True)
        tree.setDragEnabled(True)
        tree.setAcceptDrops(False)
        tree.setModel(node_model)
        node_model.itemChanged.connect(self.on_node_model_changed)

        self._table_view = table = PropertyLinksView()
        table.setDragDropOverwriteMode(False)
        table.setDragDropMode(QAbstractItemView.DropOnly)
        table_model.itemChanged.connect(self.on_table_model_changed)
        table_model.rowsInserted.connect(self.on_table_model_rows_inserted)
        table_model.restored.connect(self.on_table_model_restored)
        table.setModel(table_model)

        layout = QVBoxLayout()
        layout.addWidget(tree)
        layout.addWidget(QLabel('Drag properties from above to the area below'))
        layout.addWidget(table)
        self.setLayout(layout)

    def show(self):
        self._table_view.resizeColumnsToContents()
        self._treeview.sortByColumn(0, Qt.AscendingOrder)
        super().show()

    def on_table_model_changed(self, item):
        self._table_view.resizeColumnToContents(item.column())

    def on_table_model_rows_inserted(self, index, start, stop):
        self._table_view.resizeColumnToContents(0)

    def on_table_model_restored(self):
        self._table_view.resizeColumnsToContents()

    def on_node_model_changed(self, item):
        self._treeview.sortByColumn(0, Qt.AscendingOrder)
tofu-0.12.0/tofu/flow/runslider.py 0000664 0000000 0000000 00000016053 14237137211 0017062 0 ustar 00root root 0000000 0000000 from functools import partial
from PyQt5.QtCore import Qt, pyqtSignal, QTimer
from PyQt5 import QtGui
from PyQt5.QtWidgets import QGridLayout, QLineEdit, QWidget, QSlider
from tofu.flow.models import IntQLineEditViewItem, RangeQLineEditViewItem, UfoIntValidator
from tofu.flow.util import FlowError
class RunSlider(QWidget):
    """Window with a slider for interactively changing one numeric property.

    The underlying :class:`QSlider` always spans 0-100, which is mapped linearly onto
    [real_minimum, real_maximum]. Minimum, current value and maximum can also be typed in.
    Every change of the current value is written to the attached view item, whose
    ``property_changed`` signal is emitted, and :attr:`value_changed` is emitted.
    """

    value_changed = pyqtSignal(float)

    def __init__(self, parent=None):
        super().__init__(parent=parent, flags=Qt.Window)
        self.setWindowFlag(Qt.WindowStaysOnTopHint)
        self.setMaximumHeight(20)
        self.setMinimumWidth(600)
        self.min_edit = QLineEdit()
        self.min_edit.setToolTip('Minimum')
        self.min_edit.setMaximumWidth(80)
        self.min_edit.editingFinished.connect(self.on_min_edit_editing_finished)
        self.current_edit = QLineEdit()
        self.current_edit.setToolTip('Current value')
        self.current_edit.editingFinished.connect(self.on_current_edit_editing_finished)
        self.max_edit = QLineEdit()
        self.max_edit.setToolTip('Maximum')
        self.max_edit.setMaximumWidth(80)
        self.max_edit.editingFinished.connect(self.on_max_edit_editing_finished)
        self.slider = QSlider(orientation=Qt.Horizontal)
        self.slider.setMinimum(0)
        self.slider.setMaximum(100)
        self.slider.valueChanged.connect(self.on_slider_value_changed)
        main_layout = QGridLayout()
        main_layout.addWidget(self.current_edit, 0, 0, 1, 3, Qt.AlignHCenter)
        main_layout.addWidget(self.min_edit, 1, 0)
        main_layout.addWidget(self.slider, 1, 1)
        main_layout.addWidget(self.max_edit, 1, 2)
        self.setLayout(main_layout)
        self.view_item = None
        self.real_minimum = 0
        self.real_maximum = 100
        self.real_span = 100
        self.type = None
        self._last_value = None
        self.setEnabled(False)

    def _set_slider_silently(self, value):
        """Move the slider to the position of the real *value* without emitting valueChanged."""
        # Compute first, so that an error (e.g. zero span) cannot leave the signals blocked
        position = int(round((value - self.real_minimum) / self.real_span * 100))
        self.slider.blockSignals(True)
        try:
            self.slider.setValue(position)
        finally:
            self.slider.blockSignals(False)

    def _parse(self, text):
        """Convert *text* to the current type.

        The text goes through float first because numbers are displayed with the 'g' format,
        which may produce exponent notation (e.g. '1e+06') that int() alone rejects.

        :raises RunSliderError: if *text* is not a (finite, for integers) number
        """
        try:
            return self.type(float(text))
        except (ValueError, OverflowError):
            raise RunSliderError('Not a number')

    def _update_range(self, current=None):
        """Recompute the span and, if *current* is given, move the slider to it silently."""
        self.real_span = self.real_maximum - self.real_minimum
        if current is not None:
            self._set_slider_silently(current)

    def get_real_value(self):
        """Return the current edit's value converted to the current type."""
        # First convert possible exponents to float (in case UFO has huge defaults set)
        return self.type(float(self.current_edit.text()))

    def set_widget_value(self):
        """Write the current value to the view item and notify linked widgets."""
        value = self.get_real_value()
        self._last_value = value
        if isinstance(self.view_item, RangeQLineEditViewItem):
            value = [value]
        self.view_item.set(value)
        # Notify linked widgets
        self.view_item.property_changed.emit(self.view_item)

    def set_current_validator(self):
        """Restrict the current edit to [real_minimum, real_maximum]."""
        if self.type == int:
            validator = UfoIntValidator(self.real_minimum, self.real_maximum)
        else:
            validator = QtGui.QDoubleValidator(self.real_minimum, self.real_maximum, 1000)
        self.current_edit.setValidator(validator)

    def setup(self, view_item):
        """Attach the slider to *view_item*.

        Single-entry range items get a float range of current +/- 10 % (+/- 100 for zero),
        other items the limits of their validator. Returns False if *view_item* is already
        attached or is a range with more than one entry, True otherwise.
        """
        if self.view_item == view_item:
            return False
        current = view_item.get()
        if isinstance(view_item, RangeQLineEditViewItem):
            if len(current) > 1:
                return False
            self.type = float
            current = current[0]
            d_current = 0.1 * abs(current) if current else 100
            self.real_minimum = current - d_current
            self.real_maximum = current + d_current
        else:
            self.type = int if isinstance(view_item, IntQLineEditViewItem) else float
            self.real_minimum = view_item.widget.validator().bottom()
            self.real_maximum = view_item.widget.validator().top()
        self.view_item = view_item
        self._update_range(current=current)
        _set_number(self.min_edit, self.real_minimum)
        _set_number(self.max_edit, self.real_maximum)
        _set_number(self.current_edit, current)
        self._last_value = current
        self.setEnabled(True)
        self.set_current_validator()
        return True

    def reset(self):
        """Detach from the view item and restore the initial state."""
        self.real_minimum = 0
        self.real_maximum = 100
        self.real_span = 100
        self._last_value = None
        self.type = None
        self.min_edit.setText('')
        self.max_edit.setText('')
        self.current_edit.setText('')
        self.setWindowTitle('')
        self.view_item = None
        self.setEnabled(False)

    def on_slider_value_changed(self, value):
        """Show the new value immediately, apply it after 100 ms if the slider stays put."""
        def delayed_update(init_value):
            # Skip if the slider moved on in the meantime or reset() detached the view item
            if self.view_item and init_value == self.slider.value():
                self.set_widget_value()
                self.value_changed.emit(real_value)

        if not self.view_item:
            return
        real_value = self.slider.value() / 100 * self.real_span + self.real_minimum
        _set_number(self.current_edit, self.type(real_value))
        QTimer.singleShot(100, partial(delayed_update, value))

    def on_current_edit_editing_finished(self):
        """Apply a typed current value."""
        if not self.view_item:
            return
        value = self._parse(self.current_edit.text())
        if value == self._last_value:
            # Nothing new, do not emit value_changed signal in case the app is closing
            return
        self._set_slider_silently(value)
        self.set_widget_value()
        self.value_changed.emit(value)

    def on_min_edit_editing_finished(self):
        """Apply a typed minimum, clamping the current value if necessary.

        :raises RunSliderError: on invalid input or if minimum >= maximum
        """
        if not self.view_item:
            return
        value = self._parse(self.min_edit.text())
        if value >= self.real_maximum:
            raise RunSliderError('Minimum must be smaller than maximum')
        current = self.get_real_value()
        self.real_minimum = value
        if current < self.real_minimum:
            current = self.real_minimum
            _set_number(self.current_edit, current)
            self.set_widget_value()
            self.value_changed.emit(current)
        self._update_range(current=current)
        self.set_current_validator()

    def on_max_edit_editing_finished(self):
        """Apply a typed maximum, clamping the current value if necessary.

        :raises RunSliderError: on invalid input or if maximum <= minimum
        """
        if not self.view_item:
            return
        value = self._parse(self.max_edit.text())
        if value <= self.real_minimum:
            raise RunSliderError('Maximum must be greater than minimum')
        current = self.get_real_value()
        self.real_maximum = value
        if current > self.real_maximum:
            current = self.real_maximum
            _set_number(self.current_edit, current)
            self.set_widget_value()
            self.value_changed.emit(current)
        self._update_range(current=current)
        self.set_current_validator()
def _set_number(edit, number):
edit.setText('{:g}'.format(number))
class RunSliderError(FlowError):
    """Raised by :class:`RunSlider` on invalid user input (non-numbers, inverted limits)."""
tofu-0.12.0/tofu/flow/scene.py 0000664 0000000 0000000 00000041576 14237137211 0016160 0 ustar 00root root 0000000 0000000 import logging
import numpy as np
import networkx as nx
from PyQt5.QtCore import pyqtSignal, QObject
from PyQt5.QtWidgets import QInputDialog
from qtpynodeeditor import FlowScene, NodeDataModel, PortType, opposite_port
from tofu.flow.models import (BaseCompositeModel, ImageViewerModel, PropertyModel,
UFO_DATA_TYPE, get_composite_model_class,
get_composite_model_classes_from_json)
from tofu.flow.util import CompositeConnection, FlowError, saved_kwargs
from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel
LOG = logging.getLogger(__name__)
class UfoScene(FlowScene):
nodes_duplicated = pyqtSignal(list, dict)
# view item, its name and model name
item_focus_in = pyqtSignal(QObject, str, str, NodeDataModel)
def __init__(self, registry=None, style=None, parent=None,
             allow_node_creation=True, allow_node_deletion=True):
    """Create the scene together with its node tree and property links models."""
    super().__init__(registry=registry, style=style, parent=parent,
                     allow_node_creation=allow_node_creation,
                     allow_node_deletion=allow_node_deletion)
    styles = self.style_collection
    styles.node.opacity = 1
    styles.connection.use_data_defined_colors = True
    # Composite name -> {caption: node state} of the nodes merged into that composite
    self._composite_nodes = {}
    self._selected_nodes_on_disabled = []
    self.node_model = NodeTreeModel()
    self.node_model.setColumnCount(1)
    self.property_links_model = PropertyLinksModel(self.node_model)
    self.node_double_clicked.connect(self.on_node_double_clicked)
def __getstate__(self):
    """Return the base scene state extended by the serialized property links."""
    scene_state = super().__getstate__()
    scene_state.update({'property-links': self.property_links_model.save()})
    return scene_state
def __setstate__(self, doc):
    """Restore the scene from *doc*.

    Composite model classes contained in the nodes are registered first, then the base scene
    is restored and finally the property links (if *doc* contains any).
    """
    for node_state in doc['nodes']:
        model_state = node_state['model']
        if 'models' not in model_state or 'connections' not in model_state:
            continue
        for composite_class in get_composite_model_classes_from_json(model_state):
            self.registry.register_model(composite_class, category='Composite',
                                         registry=self.registry)
    super().__setstate__(doc)
    if 'property-links' in doc:
        self.node_model.set_nodes(self.nodes.values())
        self.property_links_model.restore(doc['property-links'], self.nodes)
def create_node(self, data_model, restore_links=True):
    """Overrides :class:`FlowScene` in order to create a node with *data_model* with a unique
    caption. Composite nodes also get their links restored unless *restore_links* is False.
    """
    LOG.debug(f'Create node with model {data_model}')
    node = super().create_node(data_model)
    self._setup_new_node(node)
    if not restore_links or not isinstance(node.model, BaseCompositeModel):
        return node
    node.model.restore_links(node)
    return node
def restore_node(self, node_json):
    """Restore a node from *node_json* within the registry's ``saved_kwargs`` context
    (presumably supplying the model's stored arguments) and set it up like a new node."""
    model_json = node_json['model']
    LOG.debug(f"Restore node with model {model_json['name']}")
    with saved_kwargs(self.registry, model_json):
        restored = super().restore_node(node_json)
    self._setup_new_node(restored)
    return restored
def on_item_focus_in(self, view_item, prop_name, caption, model):
    """Forward a composite model's ``item_focus_in`` through the scene's own signal."""
    signal_args = (view_item, prop_name, caption, model)
    self.item_focus_in.emit(*signal_args)
def _setup_new_node(self, node):
    """Give *node* a unique caption, add it to the node tree and wire up composites."""
    self._set_unique_caption(node)
    self.node_model.add_node(node)
    model = node.model
    if not isinstance(model, BaseCompositeModel):
        return
    model.property_links_model = self.property_links_model
    model.item_focus_in.connect(self.on_item_focus_in)
def _set_unique_caption(self, new_node):
caption = new_node.model.caption
captions = [node.model.caption for node in self.nodes.values() if node != new_node]
if caption in captions:
fmt = new_node.model.base_caption + ' {}'
i = 2
while fmt.format(i) in captions:
i += 1
caption = fmt.format(i)
new_node.model.caption = caption
def remove_node(self, node):
    """Clean up *node*'s model and remove the node from the composite bookkeeping, the node
    tree model and the scene."""
    model = node.model
    if hasattr(model, 'cleanup'):
        model.cleanup()
    if isinstance(model, BaseCompositeModel) and model.name in self._composite_nodes:
        del self._composite_nodes[model.name]
    self.node_model.remove_node(node)
    super().remove_node(node)
def is_selected_one_composite(self):
    """Return True if exactly one node is selected and it is a composite."""
    selection = self.selected_nodes()
    return len(selection) == 1 and isinstance(selection[0].model, BaseCompositeModel)
def skip_nodes(self):
    """Toggle skipping of the selected nodes.

    Only nodes with exactly one input and one output port, both of UFO data type, can be
    skipped. All selected nodes are validated first, so either all of them are toggled or
    none. Skipped nodes and all their connections are drawn semi-transparent.

    :raises FlowError: if any selected node cannot be skipped
    """
    selected_nodes = self.selected_nodes()
    # First check if the selected nodes may be skipped
    for node in selected_nodes:
        if (node.model.num_ports[PortType.input] != 1
                or node.model.num_ports[PortType.output] != 1):
            raise FlowError('Only nodes with one input and one output can be skipped')
        ports = list(node.state.ports)
        if ports[0].data_type != UFO_DATA_TYPE or ports[1].data_type != UFO_DATA_TYPE:
            raise FlowError('Only tasks with UFO input and output can be skipped')
    # And only if all is fine, then skip them
    for node in selected_nodes:
        node.model.skip = not node.model.skip
        opacity = 0.5 if node.model.skip else 1
        # A port may be unconnected and an output may feed several nodes, so dim every
        # existing connection instead of indexing the first one
        for connection in (*node.state.input_connections, *node.state.output_connections):
            connection.graphics_object.setOpacity(opacity)
        node.graphics_object.setOpacity(opacity)
def auto_fill(self):
    """Call ``auto_fill`` on every property model in the scene, including composite leaves."""
    for node in self.nodes.values():
        model = node.model
        if isinstance(model, BaseCompositeModel):
            paths = model.get_leaf_paths()
        else:
            paths = [[model]]
        for leaf in (path[-1] for path in paths):
            if isinstance(leaf, PropertyModel):
                leaf.auto_fill()
def copy_nodes(self):
    """Duplicate the selected nodes and the connections among them.

    Emits :attr:`nodes_duplicated` with the selected nodes and a {original: copy} dict.
    """
    selected_nodes = self.selected_nodes()
    copies = {}
    for original in selected_nodes:
        duplicate = self.create_node(original.model)
        copies[original] = duplicate
        duplicate.model.restore(original.model.save(), restore_caption=False)
    # Recreate only connections whose both ends have been copied
    for original, duplicate in copies.items():
        for connection in self.connections:
            in_port, out_port = connection.ports[0], connection.ports[1]
            if in_port.node == original and out_port.node in copies:
                self.create_connection_by_index(duplicate,
                                                in_port.index,
                                                copies[out_port.node],
                                                out_port.index,
                                                None)
    self.nodes_duplicated.emit(selected_nodes, copies)
def create_composite(self):
    """Replace the selected nodes by one node of a newly registered composite model.

    The user is asked for a name. A composite model class is built from the selected nodes
    (their saved state and positions), the connections among them and the property links
    among their leaf models, registered under category 'Composite' and instantiated. The
    selected nodes are removed, outside connections are re-attached to the new node, which is
    placed at the mean position of the selection and selected.

    :return: the new composite node, or None if the name dialog was cancelled.
    :raises FlowError: if a model with the given name has already been registered.
    """
    composite_name, ok = QInputDialog.getText(None, 'Create Composite Node', 'Name:')
    if not ok:
        return
    if composite_name in self.registry.registered_model_creators():
        raise FlowError(f'Composite node with name "{composite_name}" has already '
                        'been registered')
    # Store the original nodes' state under the composite name; expand_composite() looks it
    # up in self._composite_nodes
    self._composite_nodes[composite_name] = {}
    # (outside port, (inside node caption, port type, port index)) pairs to reconnect later
    connection_replacements = []
    models = []
    connections = []
    selected_nodes = self.selected_nodes()
    for node in selected_nodes:
        # The caption serves as the node's unique name inside the composite
        unique_name = node.model.caption
        models.append((node.model.name,
                       node.model.save(),
                       True,
                       node.__getstate__()['position']))
        self._composite_nodes[composite_name][unique_name] = node.__getstate__()

    # Connections
    # Ports whose connection has already been turned into a CompositeConnection; every
    # internal connection is seen from both of its ends and must be added only once
    assigned_ports = []
    x = []
    y = []
    for node in selected_nodes:
        x.append(node.position.x())
        y.append(node.position.y())
        for port_type in ['input', 'output']:
            for index, port in node[port_type].items():
                if port.connections:
                    # We allow only one connection
                    conn = port.connections[0]
                    other_port = conn.ports[0] if conn.ports[1] == port else conn.ports[1]
                    other = conn.get_node(opposite_port(port_type))
                    if (other in selected_nodes and port not in assigned_ports
                            and other_port not in assigned_ports):
                        # Connection between two selected nodes becomes internal to the
                        # composite
                        if port_type == PortType.input:
                            to_node_name = node.model.caption
                            to_node_index = index
                            from_node_name = other.model.caption
                            from_node_index = other_port.index
                        else:
                            to_node_name = other.model.caption
                            to_node_index = other_port.index
                            from_node_name = node.model.caption
                            from_node_index = index
                        conn = CompositeConnection(from_node_name, from_node_index,
                                                   to_node_name, to_node_index)
                        connections.append(conn)
                        assigned_ports.append(port)
                    if other not in selected_nodes:
                        # Connection reaches to a node outside selection; re-create it on the
                        # composite node once it exists
                        inside = (node.model.caption, port_type, index)
                        connection_replacements.append((other_port, inside))

    # Get links which will be internal to the newly created model
    node_models = []
    for selected_node in self.selected_nodes():
        if isinstance(selected_node.model, BaseCompositeModel):
            paths = selected_node.model.get_leaf_paths()
        else:
            paths = [[selected_node.model]]
        node_models += [path[-1] for path in paths]
    internal_links = list(self.property_links_model.get_model_links(node_models).values())

    composite = get_composite_model_class(composite_name,
                                          models,
                                          connections,
                                          links=internal_links)
    self.registry.register_model(composite,
                                 category='Composite',
                                 registry=self.registry)
    node = self.create_node(composite, restore_links=False)
    for selected_node in selected_nodes:
        if isinstance(selected_node.model, BaseCompositeModel):
            # Get all leaf PropertyModel instances
            paths = selected_node.model.get_leaf_paths()
        else:
            paths = [[selected_node.model]]
        # In case selected node is composite, replace all leaf node links
        for path in paths:
            new_model = node.model.get_model_from_path([model.caption for model in path])
            self.property_links_model.replace_item(node, new_model, path[-1])
        self.remove_node(selected_node)
    for outside_port, inside in connection_replacements:
        port_type, index = node.model.get_outside_port(*inside)
        self.create_connection(outside_port, node[port_type][index], check_cycles=False)
    # Put the new composite node to the average of x and y position of the selected nodes
    node.position = (np.mean(x), np.mean(y))
    node.graphics_object.setSelected(True)

    return node
def on_node_double_clicked(self, node):
    """Forward a double click to the node's model together with the scene's first view."""
    views = self.views()
    if not views:
        return
    node.model.double_clicked(views[0])
def expand_composite(self, node):
    """Expand composite *node* into the scene.

    The original node states stored for its model name are passed along when available
    (None otherwise).
    """
    stored = self._composite_nodes.get(node.model.name)
    return node.model.expand_into_scene(self, node, original_nodes=stored)
def is_fully_connected(self):
    """Are all the ports in all nodes connected?

    Returns True for an empty scene.
    """
    return all(port.connections
               for node in self.nodes.values()
               for port_type in ('input', 'output')
               for port in node[port_type].values())
def get_simple_node_graphs(self):
    """
    Get a graph from the scene without composite nodes which can be directly used by the
    execution.

    Skipped nodes are left out: connections into a skipped node are dropped, and connections
    out of one are re-attached to the output port which feeds the skipped node (repeatedly,
    for chains of skipped nodes). Composite nodes are then expanded until only simple models
    remain.

    :return: list of subgraphs, one per weakly connected component.
    :raises FlowError: if a skipped node which has to be bypassed has no input connection.
    """
    def get_composite(graph):
        """Get first found composite model (None if there is none)."""
        for model in graph.nodes:
            if isinstance(model, BaseCompositeModel):
                return model

    def replace_edge(graph, composite, edges, port_type):
        """Replace interface edges (going in or out from the composite model)."""
        for edge in edges:
            ports = graph.edges[edge]
            other = edge[0] if port_type == PortType.input else edge[1]
            model, index = composite.get_model_and_port_index(port_type, ports[port_type])
            if model not in graph:
                graph.add_node(model)
            if port_type == PortType.input:
                source = other
                dest = model
                input_port = index
                output_port = ports[PortType.output]
            else:
                source = model
                dest = other
                input_port = ports[PortType.input]
                output_port = index
            LOG.debug(f'Adding edge {source.name}@{output_port} -> {dest.name}@{input_port}')
            graph.add_edge(source, dest, input=input_port, output=output_port)

    def replace_composite(graph, composite):
        """Expand *composite* into *graph* and re-route its interface edges."""
        composite.expand_into_graph(graph)
        edges = graph.in_edges(composite, keys=True)
        replace_edge(graph, composite, edges, PortType.input)
        edges = graph.out_edges(composite, keys=True)
        replace_edge(graph, composite, edges, PortType.output)
        graph.remove_node(composite)

    # Initial graph with composite nodes. We need a multigraph because composite nodes may have
    # many outputs which can lead to a same destination node.
    graph = nx.MultiDiGraph()
    for node in self.nodes.values():
        if not node.model.skip:
            graph.add_node(node.model)

    for conn in self.connections:
        p_dest, p_source = conn.ports
        if p_dest.node.model.skip:
            LOG.debug(f'Skiping connection {p_source.node.model.name} -> '
                      f'{p_dest.node.model.name}')
            continue
        while p_source.node.model.skip:
            LOG.debug(f'Skiping connection {p_source.node.model.name} -> '
                      f'{p_dest.node.model.name}')
            input_connections = p_source.node.state.input_connections
            if not input_connections:
                raise FlowError(f'Skipped node {p_source.node.model.name} has no input '
                                'connection')
            # Continue from the output port which actually feeds the skipped node; the
            # upstream node may have several outputs, so its first output port is not
            # necessarily the right one
            _, p_source = input_connections[0].ports
        graph.add_edge(p_source.node.model, p_dest.node.model, input=p_dest.index,
                       output=p_source.index)

    # Expand composite nodes until there are only simple ones left
    model = get_composite(graph)
    while model:
        LOG.debug(f'Replacing composite {model.name}')
        replace_composite(graph, model)
        model = get_composite(graph)

    components = nx.weakly_connected_components(graph)

    return [nx.subgraph(graph, component) for component in components]
def set_enabled(self, enabled):
    """Enable or disable editing of the scene.

    Toggles node creation/deletion and the enabled state of all connections and of all nodes
    except image viewer nodes. When disabling, the currently selected nodes are remembered;
    when enabling again, they are re-selected and the memory is cleared.
    """
    selected_nodes = self.selected_nodes()
    self.allow_node_creation = enabled
    self.allow_node_deletion = enabled
    for node in self.nodes.values():
        if not isinstance(node.model, ImageViewerModel):
            # Image viewer nodes are never disabled
            node.graphics_object.setEnabled(enabled)
        if enabled:
            if node in self._selected_nodes_on_disabled:
                node.graphics_object.setSelected(True)
        elif node in selected_nodes and node not in self._selected_nodes_on_disabled:
            # Avoid duplicates when the scene is disabled more than once
            self._selected_nodes_on_disabled.append(node)
    for conn in self.connections:
        # Use the public accessor, as done for connections elsewhere in this class
        conn.graphics_object.setEnabled(enabled)
    if enabled:
        self._selected_nodes_on_disabled = []
tofu-0.12.0/tofu/flow/util.py 0000664 0000000 0000000 00000004432 14237137211 0016026 0 ustar 00root root 0000000 0000000 import contextlib
import json
import pkg_resources
from PyQt5.QtCore import Qt
from qtpynodeeditor import PortType
# Custom item data roles, allocated consecutively after Qt.UserRole
MODEL_ROLE = Qt.UserRole + 1
PROPERTY_ROLE = MODEL_ROLE + 1
NODE_ROLE = PROPERTY_ROLE + 1

# Configuration shipped with the package. JSON is UTF-8 by specification, so do not depend on
# the platform's default (locale) encoding.
with open(pkg_resources.resource_filename(__name__, 'config.json'), encoding='utf-8') as f:
    ENTRIES = json.load(f)
def get_config_key(*keys, default=None):
    """Look up a value in the nested configuration ``ENTRIES``.

    ``get_config_key('a', 'b')`` returns ``ENTRIES['a']['b']``. If a key is missing, or a value
    on the way is not a dictionary, *default* is returned. Missing keys are detected by
    membership, not by comparing values with *default*, so stored values equal to *default*
    are handled correctly.

    :raises ValueError: if no key is given.
    """
    if not keys:
        raise ValueError('At least one key must be given')
    current = ENTRIES
    for key in keys:
        if not isinstance(current, dict) or key not in current:
            return default
        current = current[key]

    return current
@contextlib.contextmanager
def saved_kwargs(registry, state):
    """
    Tell the registry to use the number of saved inputs for model creation but only for one model
    creation, i.e. reset the context afterward.

    If *state* has a 'num-inputs' entry, ``num_inputs`` in the registered creator's keyword
    arguments is set to it for the duration of the context. Afterwards a previously present
    ``num_inputs`` value is restored, otherwise the key is removed again.
    """
    if 'num-inputs' not in state:
        yield
        return

    kwargs = registry.registered_model_creators()[state['name']][1]
    missing = object()
    previous = kwargs.get('num_inputs', missing)
    kwargs['num_inputs'] = state['num-inputs']
    try:
        yield
    finally:
        if previous is missing:
            kwargs.pop('num_inputs', None)
        else:
            kwargs['num_inputs'] = previous
class CompositeConnection:

    """Connection between two nodes inside a composite node.

    The end points are identified by the nodes' unique names and port indices; data flows
    from the output port ``from_port_index`` of ``from_unique_name`` to the input port
    ``to_port_index`` of ``to_unique_name``.
    """

    def __init__(self, from_unique_name, from_port_index, to_unique_name, to_port_index):
        if from_unique_name == to_unique_name:
            raise ValueError('from_unique_name and to_unique_name must be different')
        self.from_unique_name, self.from_port_index = from_unique_name, from_port_index
        self.to_unique_name, self.to_port_index = to_unique_name, to_port_index

    def contains(self, unique_name, port_type, port_index):
        """Is the given port of node *unique_name* an end point of this connection?

        Output ports are matched against the source end, all others against the target end.
        """
        if port_type == PortType.output:
            return unique_name == self.from_unique_name and port_index == self.from_port_index
        return unique_name == self.to_unique_name and port_index == self.to_port_index

    def save(self):
        """Serialize to ``[from_name, from_index, to_name, to_index]``."""
        return [self.from_unique_name, self.from_port_index,
                self.to_unique_name, self.to_port_index]

    def __repr__(self):
        return 'Connection({}@{} -> {}@{})'.format(*self.save())

    def __str__(self):
        return repr(self)
class FlowError(Exception):
    """Generic error of the flow package; more specific errors (e.g. image viewing) derive from it."""
tofu-0.12.0/tofu/flow/viewer.py 0000664 0000000 0000000 00000045100 14237137211 0016347 0 ustar 00root root 0000000 0000000 import logging
import numpy as np
import os
from PyQt5 import QtGui
from PyQt5.QtCore import Qt
from PyQt5.QtWidgets import QFileDialog, QGridLayout, QLabel, QLineEdit, QMenu, QWidget, QSlider
from tofu.flow.util import FlowError
# Module-level logger of the image viewer
LOG = logging.getLogger(__name__)
class ScreenImage:

    """On-screen image representation.

    Keeps a float32 copy of the image, its gray value range (*minimum*, *maximum*) and the
    display levels (*black_point*, *white_point*) which are mapped to 0 and 255 on screen.
    """

    def __init__(self, image=None):
        self._black_point = None
        self._white_point = None
        self.minimum = None
        self.maximum = None
        self.image = image

    @property
    def image(self):
        return self._image

    @image.setter
    def image(self, image):
        """
        Keep the minimum, maximum, black and white points as they are so that images don't
        flicker when going through a sequence.
        """
        self._image = image
        if self._image is not None:
            self._image = image.astype(np.float32)
            if self.minimum is None:
                self.minimum = np.nanmin(self._image)
            if self.maximum is None:
                self.maximum = np.nanmax(self._image)
            if self.black_point is None:
                self.black_point = self.minimum
            if self.white_point is None:
                self.white_point = self.maximum

    @property
    def white_point(self):
        return self._white_point

    @white_point.setter
    def white_point(self, value):
        if self.black_point is not None and value < self.black_point:
            raise ImageViewingError('White point cannot be smaller than black point')
        self._white_point = value

    @property
    def black_point(self):
        return self._black_point

    @black_point.setter
    def black_point(self, value):
        if self.white_point is not None and value > self.white_point:
            raise ImageViewingError('Black point cannot be greater than white point')
        self._black_point = value

    def _set_points(self, black, white):
        """Set both levels (*black* <= *white*) in an order which satisfies the setters' checks."""
        if self.white_point is not None and black > self.white_point:
            # New levels lie above the current white point, move the white point first
            self.white_point = white
            self.black_point = black
        else:
            self.black_point = black
            self.white_point = white

    def reset(self):
        """Reset black and white points."""
        if self._image is not None:
            self.minimum = np.nanmin(self._image)
            self.maximum = np.nanmax(self._image)
            self._black_point = self.minimum
            self._white_point = self.maximum

    def auto_levels(self, percentile=0.1):
        """
        Compute cumulative histogram normalized to [0, 100] and truncate gray values which fall
        below *percentile* or above 100 - *percentile*.

        Non-finite pixels (NaN, inf) are ignored; if there are no finite pixels the levels are
        left unchanged.
        """
        finite = self._image[np.isfinite(self._image)]
        if not finite.size:
            return
        hist, bins = np.histogram(finite, bins=256)
        cumsum = np.cumsum(hist) / float(np.sum(hist)) * 100
        valid = bins[np.where((cumsum > percentile) & (cumsum < 100 - percentile))]
        if len(valid):
            self._set_points(valid[0], valid[-1])
        else:
            # Histogram too concentrated (e.g. constant image): collapse levels onto the first
            # finite pixel, which is image[0, 0] for finite images
            self._set_points(finite[0], finite[0])

    def set_black_point_normalized(self, value):
        """Set black point according to *value*, where value is from interval [0, 255]."""
        native = self.convert_normalized_value_to_native(value)
        if native > self.white_point:
            raise ImageViewingError('Black point cannot be greater than white point')
        self.black_point = native

    def set_white_point_normalized(self, value):
        """Set white point according to *value*, where value is from interval [0, 255]."""
        native = self.convert_normalized_value_to_native(value)
        if native < self.black_point:
            raise ImageViewingError('White point cannot be smaller than black point')
        self.white_point = native

    def convert_normalized_value_to_native(self, value):
        """Convert *value* from interval [0, 255] to the gray value in the image."""
        if value < 0 or value > 255:
            raise ImageViewingError('Normalized value must be in interval [0, 255]')
        span = self.maximum - self.minimum

        return value / 255 * span + self.minimum

    def convert_native_value_to_normalized(self, value):
        """Convert gray value in the image to a normalized value in interval [0, 255]."""
        if value < self.minimum or value > self.maximum:
            raise ImageViewingError(f'Value must be in interval [{self.minimum}, {self.maximum}]')
        span = self.maximum - self.minimum

        return (value - self.minimum) / span * 255 if span > 0 else 0

    def get_pixmap(self, downsampling=1):
        """Get :class:`QPixmap` for display, taking every *downsampling*-th pixel."""
        if self.black_point is None or self.white_point is None:
            raise ImageViewingError('Image has not been set')
        image = self.image[::downsampling, ::downsampling] - self.black_point
        span = self.white_point - self.black_point
        if span > 0:
            image = np.clip(image * 255 / span, 0, 255)
        else:
            # Black and white points coincide: show a binary image instead of letting values
            # outside [0, 255] wrap around in the uint8 conversion below
            image = np.where(image > 0, 255, 0)
        # NaN has no defined uint8 value, display it as black
        image = np.nan_to_num(image, nan=0).astype(np.uint8)
        qim = QtGui.QImage(image, image.shape[1], image.shape[0],
                           image[0].nbytes, QtGui.QImage.Format.Format_Grayscale8)

        return QtGui.QPixmap.fromImage(qim)
class ImageLabel(QLabel):

    """QLabel holding the image data."""

    def __init__(self, screen_image=None, parent=None):
        super().__init__(parent=parent)
        self.screen_image = screen_image

    def updateImage(self):
        """Render the screen image scaled to the label size, downsampling big images first."""
        if self.screen_image and self.screen_image.image is not None:
            # Guard against division by zero for a zero-sized label
            width = max(self.width(), 1)
            height = max(self.height(), 1)
            hd = self.screen_image.image.shape[1] // width
            vd = self.screen_image.image.shape[0] // height
            downsampling = max(min(hd, vd), 1)
            pixmap = self.screen_image.get_pixmap(downsampling=downsampling)
            self.setPixmap(pixmap.scaled(width, height, Qt.KeepAspectRatio))

    def resizeEvent(self, event):
        super().resizeEvent(event)
        self.updateImage()
class ImageViewer(QWidget):

    """Widget showing a stack of images.

    It has a slider for the image index and sliders/edits for the black and white points.
    If pyqtgraph is installed, the images can be popped up in a separate pyqtgraph window
    which is kept in sync with this widget. If imageio is installed, they can be saved.
    """

    edit_height = 16
    edit_width = 100

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._images = None
        self._last_save_dir = '.'
        # Pyqtgraph popped up window
        self._pg_window = None
        self.screen_image = ScreenImage()
        self.new_image_auto_levels = True
        self.label = ImageLabel(self.screen_image)
        self.label.setAlignment(Qt.AlignVCenter | Qt.AlignCenter)

        self.slider_edit = QLineEdit()
        self.slider_edit.setFixedSize(self.edit_width, self.edit_height)
        self.slider_edit.returnPressed.connect(self.on_slider_edit_return_pressed)
        self.slider = QSlider(Qt.Horizontal)
        validator = QtGui.QIntValidator(0, self.slider.maximum())
        self.slider_edit.setValidator(validator)
        self.slider.valueChanged.connect(self.on_slider_value_changed)

        self.min_slider = QSlider(Qt.Horizontal)
        self.max_slider = QSlider(Qt.Horizontal)
        self.min_slider_edit = QLineEdit()
        self.min_slider_edit.setFixedSize(self.edit_width, self.edit_height)
        self.max_slider_edit = QLineEdit()
        self.max_slider_edit.setFixedSize(self.edit_width, self.edit_height)
        self.min_slider.setMinimum(0)
        self.max_slider.setMinimum(0)
        self.min_slider.setMaximum(255)
        self.max_slider.setMaximum(255)
        # Set before connecting the handlers, so it does not trigger them
        self.max_slider.setValue(255)
        self.min_slider.valueChanged.connect(self.on_min_slider_value_changed)
        self.max_slider.valueChanged.connect(self.on_max_slider_value_changed)
        self.min_slider_edit.returnPressed.connect(self.on_min_slider_edit_return_pressed)
        self.max_slider_edit.returnPressed.connect(self.on_max_slider_edit_return_pressed)

        # Tooltips, every edit shares the tooltip of its slider
        self.slider.setToolTip('Image index in sequence')
        self.slider_edit.setToolTip(self.slider.toolTip())
        self.min_slider.setToolTip('Black point')
        self.min_slider_edit.setToolTip(self.min_slider.toolTip())
        self.max_slider.setToolTip('White point')
        self.max_slider_edit.setToolTip(self.max_slider.toolTip())

        mainLayout = QGridLayout()
        mainLayout.addWidget(self.label, 0, 0, 1, 2)
        mainLayout.addWidget(self.slider_edit, 1, 0)
        mainLayout.addWidget(self.slider, 1, 1)
        mainLayout.addWidget(self.min_slider_edit, 2, 0)
        mainLayout.addWidget(self.min_slider, 2, 1)
        mainLayout.addWidget(self.max_slider_edit, 3, 0)
        mainLayout.addWidget(self.max_slider, 3, 1)
        self.setLayout(mainLayout)

    def contextMenuEvent(self, event):
        """Show the context menu (levels, pop up, save) and execute the chosen action."""
        contextMenu = QMenu(self)
        reset_action = contextMenu.addAction('Reset')
        auto_levels_action = contextMenu.addAction('Auto Levels')
        new_image_auto_levels = contextMenu.addAction('Auto Levels on New Image')
        new_image_auto_levels.setCheckable(True)
        new_image_auto_levels.setChecked(self.new_image_auto_levels)
        pop_action = None
        save_action = None
        try:
            import pyqtgraph  # noqa: F401 -- only probes availability
        except ImportError:
            LOG.debug('pyqtgraph not installed, pop up option disabled')
        else:
            if self._images is not None and not self.popup_visible:
                pop_action = contextMenu.addAction('Pop Up')
        try:
            import imageio
        except ImportError:
            LOG.debug('imageio not installed, save option disabled')
        else:
            if self._images is not None:
                save_action = contextMenu.addAction('Save')

        action = contextMenu.exec_(self.mapToGlobal(event.pos()))
        if not action:
            return

        # save_action is only set when imageio could be imported
        if action == save_action:
            file_name, _ = QFileDialog.getSaveFileName(None,
                                                       "Select File Name",
                                                       self._last_save_dir,
                                                       "Images (*.tif *.png *.jpg)")
            if file_name:
                if not os.path.splitext(file_name)[1]:
                    file_name += '.tif'
                self._last_save_dir = os.path.dirname(file_name)
                if self._images.shape[0] == 1:
                    imageio.imsave(file_name, self._images[0])
                else:
                    if os.path.splitext(file_name)[1] != '.tif':
                        raise ImageViewingError('3D data can be stored only in tif format')
                    # bigtiff size from tifffile
                    imageio.volsave(file_name, self._images,
                                    bigtiff=self._images.nbytes > 2 ** 32 - 2 ** 25)
        elif action == reset_action:
            self.reset_clim()
        elif action == auto_levels_action:
            self.reset_clim(auto=True)
        elif action == new_image_auto_levels:
            self.new_image_auto_levels = action.isChecked()
        elif action == pop_action:
            self.popup()

    @property
    def images(self):
        return self._images

    @images.setter
    def images(self, images):
        """Show *images* (2D image or 3D stack); None clears the viewer."""
        was_none = self._images is None
        self._images = images
        if self._images is None:
            self.screen_image.image = None
            self.set_enabled_adjustments(False)
            return
        self.set_enabled_adjustments(True)
        if self._images.ndim == 2:
            self._images = self._images[np.newaxis, :, :]
        if self._images.shape[0] == 1:
            self.slider.hide()
            self.slider_edit.hide()
        else:
            self.slider.setMaximum(len(self._images) - 1)
            self.slider.show()
            self.slider_edit.show()
        self.slider_edit.setText('0')
        self.slider.blockSignals(True)
        self.slider.setValue(0)
        self.slider.blockSignals(False)
        if self._pg_window is not None:
            self._update_pg_window_images()
            self._update_pg_window_index()
        self.screen_image.image = self._images[0]
        if was_none or self.new_image_auto_levels:
            self.reset_clim(auto=True)
        else:
            self.label.updateImage()
        validator = self.min_slider_edit.validator()
        if validator is None:
            validator = QtGui.QDoubleValidator(self.screen_image.minimum,
                                               self.screen_image.maximum, 100)
            self.min_slider_edit.setValidator(validator)
            self.max_slider_edit.setValidator(validator)
        else:
            validator.setRange(self.screen_image.minimum, self.screen_image.maximum, 100)
        self.slider_edit.validator().setTop(self.slider.maximum())
        if self.label.width() < 256 or self.label.height() < 256:
            self.label.resize(256, 256)

    def append(self, images):
        """Append *images* (2D image or 3D stack) to the displayed ones."""
        if self.images is None:
            self.images = images
        else:
            if images.ndim == 2:
                images = images[np.newaxis, :, :]
            if images.shape[1:] != self.images.shape[1:]:
                raise ImageViewingError('Appended images have different shape '
                                        f'{images.shape[1:]} than the displayed ones '
                                        f'{self.images.shape[1:]}')
            self.images = np.concatenate((self.images, images))

    def set_enabled_adjustments(self, enabled):
        self.slider.setEnabled(enabled)
        self.slider_edit.setEnabled(enabled)
        self.min_slider.setEnabled(enabled)
        self.min_slider_edit.setEnabled(enabled)
        self.max_slider.setEnabled(enabled)
        self.max_slider_edit.setEnabled(enabled)

    def reset_clim(self, auto=False):
        """Reset the black and white points to the image range (or auto levels if *auto*)."""
        self.screen_image.reset()
        if auto:
            self.screen_image.auto_levels()
        self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point))
        self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point))
        self.set_slider_value(self.min_slider, self.screen_image.black_point)
        self.set_slider_value(self.max_slider, self.screen_image.white_point)
        self.label.updateImage()
        self._update_pg_window_lut()

    @property
    def popup_visible(self):
        return self._pg_window and self._pg_window.isVisible()

    def popup(self):
        """Show the images in a separate pyqtgraph window, created on first use."""
        import pyqtgraph
        pyqtgraph.setConfigOptions(antialias=True, imageAxisOrder='row-major')

        if self._pg_window is not None:
            if not self._pg_window.isVisible():
                self._pg_window.show()
            return

        def on_pg_window_time_changed(index, time):
            self._set_index(index)
            self.slider.blockSignals(True)
            self.slider_edit.setText(str(index))
            self.slider.setValue(index)
            self.slider.blockSignals(False)

        def on_pg_window_levels_changed(hist_item):
            minimum, maximum = hist_item.getLevels()
            if (self.screen_image.minimum <= minimum <= self.screen_image.maximum
                    and self.screen_image.minimum <= maximum <= self.screen_image.maximum):
                self.min_slider_edit.setText('{:g}'.format(minimum))
                self.set_slider_value(self.min_slider, minimum)
                self.max_slider_edit.setText('{:g}'.format(maximum))
                self.set_slider_value(self.max_slider, maximum)
                # Setting the black point first would fail when the whole level window moved
                # above the old white point, so pick an order which keeps black <= white
                if minimum > self.screen_image.white_point:
                    self.screen_image.white_point = maximum
                    self.screen_image.black_point = minimum
                else:
                    self.screen_image.black_point = minimum
                    self.screen_image.white_point = maximum
                self.label.updateImage()

        def pg_mouse_moved(ev):
            if self._pg_window.imageItem.sceneBoundingRect().contains(ev):
                image = self._pg_window.imageItem.image
                pos = self._pg_window.imageItem.mapFromScene(ev)
                # Pixel i covers [i, i + 1) in image item coordinates, so truncate; clamp for
                # the far edge which the bounding rectangle still contains
                x = min(int(pos.x()), image.shape[1] - 1)
                y = min(int(pos.y()), image.shape[0] - 1)
                self._pg_window.view.setTitle('x={}, y={}, I={:g}'.format(x, y, image[y, x]))
            else:
                self._pg_window.view.setTitle('')

        self._pg_window = pyqtgraph.ImageView(view=pyqtgraph.PlotItem())
        self._pg_window.imageItem.scene().sigMouseMoved.connect(pg_mouse_moved)
        self._pg_window.setWindowFlag(Qt.SubWindow, True)
        self._update_pg_window_images()
        self._update_pg_window_index()
        self._update_pg_window_lut()
        self._pg_window.show()
        self._pg_window.sigTimeChanged.connect(on_pg_window_time_changed)
        self._pg_window.ui.histogram.item.sigLevelsChanged.connect(on_pg_window_levels_changed)

    def cleanup(self):
        """Close the pop up window if there is one."""
        if self._pg_window:
            self._pg_window.close()
            self._pg_window = None

    def _set_index(self, index):
        self.screen_image.image = self.images[index]
        self.label.updateImage()

    def _update_pg_window_images(self):
        if self.images.shape[0] == 1:
            im_to_set = self.images[0]
        else:
            im_to_set = self.images
        self._pg_window.setImage(im_to_set, autoLevels=False)

    def _update_pg_window_index(self):
        if self._images.shape[0] > 1 and self._pg_window is not None:
            self._pg_window.blockSignals(True)
            self._pg_window.setCurrentIndex(self.slider.value())
            self._pg_window.blockSignals(False)

    def _update_pg_window_lut(self):
        if self._pg_window is not None:
            self._pg_window.ui.histogram.item.blockSignals(True)
            self._pg_window.setLevels(self.screen_image.black_point, self.screen_image.white_point)
            self._pg_window.ui.histogram.item.blockSignals(False)

    def on_slider_value_changed(self, value):
        self._set_index(value)
        self.slider_edit.setText(str(value))
        self._update_pg_window_index()

    def on_slider_edit_return_pressed(self):
        self.slider.setValue(int(self.slider_edit.text()))

    def on_min_slider_edit_return_pressed(self):
        value = float(self.min_slider_edit.text())
        if value < self.screen_image.white_point:
            self.screen_image.black_point = value
            self.set_slider_value(self.min_slider, value)
            self.label.updateImage()
            self._update_pg_window_lut()

    def on_max_slider_edit_return_pressed(self):
        value = float(self.max_slider_edit.text())
        if value > self.screen_image.black_point:
            self.screen_image.white_point = value
            self.set_slider_value(self.max_slider, value)
            self.label.updateImage()
            self._update_pg_window_lut()

    def on_min_slider_value_changed(self, value):
        self.screen_image.set_black_point_normalized(value)
        self.min_slider_edit.setText('{:g}'.format(self.screen_image.black_point))
        self.label.updateImage()
        self._update_pg_window_lut()

    def on_max_slider_value_changed(self, value):
        self.screen_image.set_white_point_normalized(value)
        self.max_slider_edit.setText('{:g}'.format(self.screen_image.white_point))
        self.label.updateImage()
        self._update_pg_window_lut()

    def set_slider_value(self, slider, value):
        """Move *slider* to gray *value* without emitting valueChanged."""
        slider.blockSignals(True)
        slider.setValue(int(self.screen_image.convert_native_value_to_normalized(value)))
        slider.blockSignals(False)
class ImageViewingError(FlowError):
    """Error raised by the image viewer (invalid levels, shapes or file formats)."""
tofu-0.12.0/tofu/genreco.py 0000664 0000000 0000000 00000073356 14237137211 0015537 0 ustar 00root root 0000000 0000000 """General projection-based reconstruction for tomographic/laminographic cone/parallel beam data
sets.
"""
import copy
import itertools
import logging
import os
import time
import numpy as np
from multiprocessing.pool import ThreadPool
from gi.repository import Ufo
from .preprocess import create_preprocessing_pipeline
from .util import (get_filtering_padding, get_reconstructed_cube_shape,
get_reconstruction_regions, get_filenames, determine_shape,
get_scarray_value, Vector)
from .tasks import get_task, get_writer
LOG = logging.getLogger(__name__)
# Number of bytes per voxel for each supported --store-type
DTYPE_CL_SIZE = dict(float=4, double=8, half=2, uchar=1, ushort=2, uint=4)
def genreco(args):
    """General projection-based reconstruction of tomographic/laminographic cone/parallel beam
    data sets.

    The reconstructed volume is split into passes (see :func:`make_runs`) and every pass is
    processed in parallel on the selected GPUs by :func:`_run`.

    :raises ValueError: if no GPU is available or --gpus contains invalid indices.
    """
    st = time.time()
    if is_output_single_file(args):
        try:
            import ufo.numpy  # noqa: F401 -- only checks that single-file output is possible
        except ImportError:
            LOG.error('You must install ufo-python-tools to be able to write single-file output')
            return
    if (args.energy is not None and args.propagation_distance is not None and not
            (args.projection_margin or args.disable_projection_crop)):
        LOG.warning('Phase retrieval without --projection-margin specification or '
                    '--disable-projection-crop may cause convolution artifacts')
    _fill_missing_args(args)
    _convert_angles_to_rad(args)
    set_projection_filter_scale(args)
    x_region, y_region, z_region = get_reconstruction_regions(args, store=True, dtype=float)
    vol_shape = get_reconstructed_cube_shape(x_region, y_region, z_region)
    bpp = DTYPE_CL_SIZE[args.store_type]
    num_voxels = vol_shape[0] * vol_shape[1] * vol_shape[2]
    vol_nbytes = num_voxels * bpp

    resources = [Ufo.Resources()]
    gpus = np.array(resources[0].get_gpu_nodes())
    gpu_indices = np.array(args.gpus or list(range(len(gpus))))
    if not len(gpu_indices):
        # min() below would otherwise fail with an unhelpful "empty sequence" error
        raise ValueError('No GPUs available')
    if min(gpu_indices) < 0 or max(gpu_indices) > len(gpus) - 1:
        raise ValueError('--gpus contains invalid indices')
    gpus = gpus[gpu_indices]
    duration = 0
    for i, gpu in enumerate(gpus):
        # Same memory query as in get_num_slices_per_gpu, by name instead of a magic number
        mem_size = gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE)
        print('Max mem for {}: {:.2f} GB'.format(i, mem_size / 2. ** 30))

    runs = make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp,
                     slices_per_device=args.slices_per_device,
                     slice_memory_coeff=args.slice_memory_coeff,
                     data_splitting_policy=args.data_splitting_policy,
                     num_gpu_threads=args.num_gpu_threads)
    # One Ufo.Resources per region of the first (largest) pass
    for i in range(len(runs[0]) - 1):
        resources.append(Ufo.Resources())

    LOG.info('Number of passes: %d', len(runs))
    LOG.debug('GPUs and regions:')
    for regions in runs:
        LOG.debug('%s', str(regions))

    for i, regions in enumerate(runs):
        duration += _run(resources, args, x_region, y_region, regions, i, vol_nbytes)

    num_gupdates = num_voxels * args.number * 1e-9
    total_duration = time.time() - st
    LOG.debug('UFO duration: %.2f s', duration)
    LOG.debug('Total duration: %.2f s', total_duration)
    LOG.debug('UFO performance: %.2f GUPS', num_gupdates / duration)
    LOG.debug('Total performance: %.2f GUPS', num_gupdates / total_duration)
def make_runs(gpus, gpu_indices, x_region, y_region, z_region, bpp, slices_per_device=None,
              slice_memory_coeff=0.8, data_splitting_policy='one', num_gpu_threads=1):
    """Split the z-range of the reconstruction into passes of per-GPU, per-thread regions.

    :param gpus: GPU nodes; used to estimate how many slices fit on them if
        *slices_per_device* is not given
    :param gpu_indices: index of every entry of *gpus*, stored in the returned regions
    :param slices_per_device: number of slices every GPU processes in one pass; if not given
        it is estimated by :func:`get_num_slices_per_gpu` with *slice_memory_coeff*
    :param data_splitting_policy: 'one' fills the last pass by maximizing the workload per GPU
        (largest GPUs first), anything else spreads it evenly among the GPUs
    :param num_gpu_threads: number of regions (threads) per GPU
    :return: list of passes, each a list of ``(gpu_index, [z_start, z_end, z_step])``; empty
        regions are never produced
    :raises RuntimeError: if no device can hold a single slice
    :raises ValueError: if *num_gpu_threads* is smaller than 1
    """
    if num_gpu_threads < 1:
        raise ValueError('num_gpu_threads must be at least 1')
    gpu_indices = np.array(gpu_indices)

    def _add_region(runs, gpu_index, current, to_process, z_start, z_step):
        """Append *current* slices of GPU *gpu_index* to the last pass, split among threads.

        Returns the next z_start and the number of slices still to process.
        """
        per_thread, remainder = divmod(current, num_gpu_threads)
        for thread in range(num_gpu_threads):
            # The last thread also takes the slices which cannot be split evenly
            count = per_thread + (remainder if thread + 1 == num_gpu_threads else 0)
            if not count:
                # Fewer slices than threads, an empty z-range would only start an idle executor
                continue
            z_end = z_start + count * z_step
            runs[-1].append((gpu_indices[gpu_index], [z_start, z_end, z_step]))
            z_start = z_end

        return z_start, to_process - current

    z_step = z_region[2]
    slice_width, slice_height, num_slices = get_reconstructed_cube_shape(x_region, y_region,
                                                                         z_region)

    if slices_per_device:
        slices_per_device = [slices_per_device] * len(gpus)
    else:
        slices_per_device = get_num_slices_per_gpu(gpus, slice_width, slice_height, bpp,
                                                   slice_memory_coeff=slice_memory_coeff)

    max_slices_per_pass = sum(slices_per_device)
    if not max_slices_per_pass:
        raise RuntimeError('None of the available devices has enough memory to store any slices')
    num_full_passes = num_slices // max_slices_per_pass
    LOG.debug('Number of slices: %d', num_slices)
    LOG.debug('Slices per device %s', slices_per_device)
    LOG.debug('Maximum slices on all GPUs per pass: %d', max_slices_per_pass)
    LOG.debug('Number of passes with full workload: %d', num_full_passes)
    # Devices which can hold at least one slice, in ascending order of capacity
    sorted_indices = np.argsort(slices_per_device)[-np.count_nonzero(slices_per_device):]

    runs = []
    z_start = z_region[0]
    to_process = num_slices

    # Create passes where all GPUs are fully loaded
    for _ in range(num_full_passes):
        runs.append([])
        for i in sorted_indices:
            z_start, to_process = _add_region(runs, i, slices_per_device[i], to_process,
                                              z_start, z_step)

    if to_process:
        if data_splitting_policy == 'one':
            # Fill the last pass by maximizing the workload per GPU
            runs.append([])
            for i in sorted_indices[::-1]:
                if not to_process:
                    break
                current = min(slices_per_device[i], to_process)
                z_start, to_process = _add_region(runs, i, current, to_process,
                                                  z_start, z_step)
        else:
            # Fill the last pass by maximizing the number of GPUs which will work
            num_gpus = len(sorted_indices)
            runs.append([])
            for j, i in enumerate(sorted_indices):
                # Current GPU will either process the maximum number of slices it can. If the number
                # of slices per GPU based on even division between them cannot saturate the GPU, use
                # this number. This way the work will be split evenly between the GPUs.
                current = max(min(slices_per_device[i], (to_process - 1) // (num_gpus - j) + 1), 1)
                z_start, to_process = _add_region(runs, i, current, to_process,
                                                  z_start, z_step)
                if not to_process:
                    break

    return runs
def get_num_slices_per_gpu(gpus, width, height, bpp, slice_memory_coeff=0.8):
    """Return, for every GPU in *gpus*, how many *width* x *height* slices with *bpp* bytes
    per voxel fit into *slice_memory_coeff* times its global memory."""
    bytes_per_slice = width * height * bpp
    return [int(np.floor(gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) * slice_memory_coeff
                         / bytes_per_slice))
            for gpu in gpus]
def _run(resources, args, x_region, y_region, regions, run_number, vol_nbytes):
    """Execute one pass on all possible GPUs with slice ranges given by *regions*. Use separate
    thread per region and optimize the read projection regions.

    :return: duration of the pass in seconds.
    :raises: any exception raised by one of the executor threads.
    """
    executors = []
    for index, (gpu_index, region) in enumerate(regions):
        region_index = run_number * len(resources) + index
        executors.append(Executor(resources[index], args, region, x_region, y_region,
                                  gpu_index, region_index))

    def start_one(index):
        return executors[index].process()

    st = time.time()
    pool = ThreadPool(processes=len(regions))
    try:
        result = pool.map_async(start_one, list(range(len(regions))))
        if is_output_single_file(args):
            import tifffile
            bigtiff = vol_nbytes > 2 ** 32 - 1
            LOG.debug('Writing BigTiff: %s', bigtiff)
            dirname = os.path.dirname(args.output)
            if dirname:
                os.makedirs(dirname, exist_ok=True)
            with tifffile.TiffWriter(args.output, append=run_number != 0,
                                     bigtiff=bigtiff) as writer:
                for executor in executors:
                    executor.consume(writer)
        # get() instead of wait(): re-raise exceptions from the executor threads, which would
        # otherwise be lost and the pass reported as successful
        result.get()
    finally:
        # Accept no more tasks; worker threads exit once the submitted ones are done
        pool.close()
    pool.join()

    return time.time() - st
def setup_graph(args, graph, x_region, y_region, region, source=None, gpu=None, do_output=True,
                index=0, make_reader=True):
    """Add a general-backprojection pipeline to *graph* and return its end nodes.

    The chain is: preprocessing (whatever ``create_preprocessing_pipeline`` returns, if
    anything) -> ``general-backproject`` -> sink. The sink is only added when *do_output* is
    set: a ``null`` task for dry runs, otherwise the writer from ``get_writer`` with the
    filename pattern ``<output>-<index>-%04i.tif``.

    When projections are filtered and cropped after backprojection, half of the filter
    padding is added to ``args.center_position_x`` (this modifies *args* in place) and, for a
    ``center-position-x`` z parameter, to the first two entries of *region*.

    :returns: ``(first, last)`` -- the first node of the chain (the preprocessing source, or
        the backprojection task if there is none) and the last one (sink or backprojection)
    """
    backproject = get_task('general-backproject', processing_node=gpu)
    if do_output:
        if args.dry_run:
            sink = get_task('null', processing_node=gpu, download=True)
        else:
            sink = get_writer(args)
            sink.props.filename = '{}-{:>03}-%04i.tif'.format(args.output, index)

    # Projection width as seen by the backprojector (transposed input swaps the axes)
    width = args.height if args.transpose_input else args.width
    if args.projection_filter != 'none' and args.projection_crop_after == 'backprojection':
        # Take projection padding into account
        padding = get_filtering_padding(width)
        args.center_position_x = [pos + padding / 2 for pos in args.center_position_x]
        if args.z_parameter == 'center-position-x':
            region = [region[0] + padding / 2, region[1] + padding / 2, region[2]]
        LOG.debug('center-position-x after padding: %g', args.center_position_x[0])

    props = backproject.props
    props.parameter = args.z_parameter
    if args.burst:
        props.burst = args.burst
    props.z = args.z
    props.region = region
    props.x_region = x_region
    props.y_region = y_region
    # These task properties are named exactly like the arguments they are copied from
    same_named = ['center_position_x', 'center_position_z']
    same_named += ['{}_position_{}'.format(part, coord)
                   for part in ('source', 'detector') for coord in 'xyz']
    same_named += ['{}_angle_{}'.format(part, coord)
                   for part in ('detector', 'axis', 'volume') for coord in 'xyz']
    for name in same_named:
        setattr(props, name, getattr(args, name))
    props.num_projections = args.number
    props.compute_type = args.compute_type
    props.result_type = args.result_type
    props.store_type = args.store_type
    props.overall_angle = args.overall_angle
    props.addressing_mode = args.genreco_padding_mode
    props.gray_map_min = args.slice_gray_map[0]
    props.gray_map_max = args.slice_gray_map[1]

    source = create_preprocessing_pipeline(args, graph, source=source,
                                           processing_node=gpu,
                                           cone_beam_weight=not args.disable_cone_beam_weight,
                                           make_reader=make_reader)
    if source:
        graph.connect_nodes(source, backproject)
        first = source
    else:
        first = backproject

    if do_output:
        graph.connect_nodes(backproject, sink)
        return (first, sink)
    return (first, backproject)
def is_output_single_file(args):
    """Return True if the output is one TIFF file instead of a series of slice files.

    That is the case when ``args.output`` ends with ``.tif`` or ``.tiff`` (case-insensitive)
    and this is not a dry run.
    """
    lowered = args.output.lower()
    if args.dry_run:
        return False
    return lowered.endswith(('.tif', '.tiff'))
def set_projection_filter_scale(args):
    """Set ``args.projection_filter_scale`` according to the beam geometry.

    Parallel beam (all source y positions infinite): 0.5, or ``0.5 * cos(axis_angle_x[0])``
    when any ``axis_angle_x`` is non-zero (laminography).
    Cone beam: ``1 / magnification ** 2``, divided once more by the magnification when all
    ``axis_angle_x`` are zero (tomography). The magnification is
    ``(source_y - detector_y) / source_y`` using the first source and detector positions.
    """
    parallel = np.all(np.isinf(args.source_position_y))
    source_y = args.source_position_y[0]
    magnification = (source_y - args.detector_position_y[0]) / source_y
    args.projection_filter_scale = 1.

    if parallel:
        if not np.any(np.array(args.axis_angle_x)):
            args.projection_filter_scale = 0.5
            return
        LOG.debug('Adjusting filter for parallel beam laminography')
        args.projection_filter_scale = 0.5 * np.cos(args.axis_angle_x[0])
        return

    args.projection_filter_scale /= magnification ** 2
    if np.all(np.array(args.axis_angle_x) == 0):
        LOG.debug('Adjusting filter for cone beam tomography')
        args.projection_filter_scale /= magnification
def _fill_missing_args(args):
    """Fill in arguments the user did not specify and return *args* (modified in place).

    Defaults: center positions in the middle of the (possibly transposed) projection shape,
    an overall angle of 360 degrees, and a number of projections taken from the length of
    ``--axis-angle-z`` (if it has more than one entry) or from the number of files matching
    ``--projections``. For dry runs the shape is stored in *args* and logged.

    :raises RuntimeError: if the number of projections must be taken from the files and
        none match
    """
    width, height = determine_shape(args, args.projections, store=False)
    if args.transpose_input:
        width, height = height, width
    if not args.center_position_x:
        args.center_position_x = [width / 2.]
    if not args.center_position_z:
        args.center_position_z = [height / 2.]

    if not args.overall_angle:
        args.overall_angle = 360.
        LOG.info('Overall angle not specified, using 360 deg')

    if not args.number:
        if len(args.axis_angle_z) > 1:
            LOG.debug("--number not specified, using length of --axis-angle-z: %d",
                      len(args.axis_angle_z))
            args.number = len(args.axis_angle_z)
        else:
            num_files = len(get_filenames(args.projections))
            if not num_files:
                raise RuntimeError("No files found in `{}'".format(args.projections))
            LOG.debug("--number not specified, using number of files matching "
                      "--projections pattern: %d", num_files)
            args.number = num_files

    if args.dry_run:
        # NOTE(review): args.number is always truthy here (it was either given, derived above
        # or a RuntimeError was raised), so this check cannot trigger -- confirm whether it was
        # meant to run before the file-based lookup
        if not args.number:
            raise ValueError('--number must be specified by --dry-run')
        determine_shape(args, args.projections, store=True)
        LOG.info('Dummy data W x H x N: {} x {} x {}'.format(args.width,
                                                               args.height,
                                                               args.number))
    return args
def _convert_angles_to_rad(args):
    """Convert all angular arguments of *args* from degrees to radians in place.

    Converts ``overall_angle``, every ``{detector,axis,volume}_angle_{x,y,z}`` and, when the
    z parameter is one of those angles, the reconstruction ``region`` as well.
    """
    prefixes = ('detector_angle', 'axis_angle', 'volume_angle')
    axes = ('x', 'y', 'z')
    angular_z_params = ['{}-{}'.format(prefix.replace('_', '-'), axis)
                        for prefix in prefixes for axis in axes]

    args.overall_angle = np.deg2rad(args.overall_angle)
    if args.z_parameter in angular_z_params:
        LOG.debug('Converting z parameter values to radians')
        args.region = _convert_list_to_rad(args.region)

    for prefix, axis in itertools.product(prefixes, axes):
        attribute = '{}_{}'.format(prefix, axis)
        setattr(args, attribute, _convert_list_to_rad(getattr(args, attribute)))
def _convert_list_to_rad(values):
return np.deg2rad(np.array(values)).tolist()
def _are_values_equal(values):
return np.all(np.array(values) == values[0])
class Executor(object):
    """Reconstruct one slice *region* on the GPU with index *gpu_index*.

    :meth:`process` builds and runs the task graph. For single-file output the graph ends in
    a ``Ufo.OutputTask`` whose buffers are pulled by :meth:`consume`; otherwise
    ``setup_graph`` adds its own sink.
    """

    def __init__(self, resources, args, region, x_region, y_region, gpu_index, region_index):
        self.resources = resources
        self.args = args
        self.region = region
        self.gpu_index = gpu_index
        self.x_region = x_region
        self.y_region = y_region
        # Used by setup_graph to number the output files of this region
        self.region_index = region_index
        self.single_file_output = is_output_single_file(self.args)
        self.output = Ufo.OutputTask() if self.single_file_output else None

    def process(self):
        """Build and run the reconstruction graph; return the scheduler's ``time`` property."""
        scheduler = Ufo.FixedScheduler()
        scheduler.set_resources(self.resources)
        graph = Ufo.TaskGraph()
        gpu = scheduler.get_resources().get_gpu_nodes()[self.gpu_index]
        geometry = CTGeometry(self.args)
        if (len(self.args.center_position_z) == 1 and
                np.modf(self.args.center_position_z[0])[0] == 0 and
                geometry.is_simple_parallel_tomo):
            LOG.info('Simple tomography with integer z center, changing to center_position_z + 0.5 '
                     'to avoid interpolation')
            geometry.args.center_position_z = (geometry.args.center_position_z[0] + 0.5,)
        if not self.args.disable_projection_crop:
            if not self.args.dry_run and (self.args.y or self.args.height or
                                          self.args.transpose_input):
                LOG.debug('--y or --height or --transpose-input specified, '
                          'not optimizing projection region')
            else:
                geometry.optimize_args(region=self.region)
        # CTGeometry works on a deep copy, so the possibly cropped arguments are private to
        # this executor
        opt_args = geometry.args
        if self.args.dry_run:
            source = get_task('dummy-data', number=self.args.number, width=self.args.width,
                              height=self.args.height)
        else:
            source = None
        last = setup_graph(opt_args, graph, self.x_region, self.y_region, self.region,
                           source=source, gpu=gpu, index=self.region_index, make_reader=True,
                           do_output=not self.single_file_output)[-1]
        if self.single_file_output:
            graph.connect_nodes(last, self.output)
        LOG.debug('Device: %d, region: %s', self.gpu_index, self.region)
        scheduler.run(graph)
        return scheduler.props.time

    def consume(self, writer):
        """Write every slice of this region with *writer* (single-file output only).

        One output buffer is pulled per entry of ``np.arange(*self.region)``. Each buffer is
        returned to the output task even if writing it raises.
        """
        import ufo.numpy
        for _ in np.arange(*self.region):
            buf = self.output.get_output_buffer()
            try:
                writer.save(ufo.numpy.asarray(buf))
            finally:
                self.output.release_output_buffer(buf)
class CTGeometry(object):
    """Reconstruction geometry, used to find the detector region a reconstruction needs.

    Works on a deep copy of *args*, so the caller's arguments are not modified. The copy is
    completed by ``determine_shape`` and ``get_reconstruction_regions`` (presumably storing
    the projection width/height and the reconstruction regions, which are read later) and
    gets default center positions in the middle of the projection.
    """

    def __init__(self, args):
        self.args = copy.deepcopy(args)
        determine_shape(self.args, self.args.projections, store=True)
        get_reconstruction_regions(self.args, store=True, dtype=float)
        self.args.center_position_x = (self.args.center_position_x or [self.args.width / 2.])
        self.args.center_position_z = (self.args.center_position_z or [self.args.height / 2.])

    @property
    def is_parallel(self):
        """True if every source y position is infinite (parallel beam)."""
        return np.all(np.isinf(self.args.source_position_y))

    @property
    def is_detector_rotated(self):
        """True if any detector angle is non-zero."""
        return (np.any(self.args.detector_angle_x) or
                np.any(self.args.detector_angle_y) or
                np.any(self.args.detector_angle_z))

    @property
    def is_axis_rotated(self):
        """True if any rotation axis angle is non-zero."""
        return (np.any(self.args.axis_angle_x) or
                np.any(self.args.axis_angle_y) or
                np.any(self.args.axis_angle_z))

    @property
    def is_volume_rotated(self):
        """True if any volume angle is non-zero."""
        return (np.any(self.args.volume_angle_x) or
                np.any(self.args.volume_angle_y) or
                np.any(self.args.volume_angle_z))

    @property
    def is_center_position_x_constant(self):
        return _are_values_equal(self.args.center_position_x)

    @property
    def is_center_position_z_constant(self):
        return _are_values_equal(self.args.center_position_z)

    @property
    def is_center_constant(self):
        return self.is_center_position_x_constant and self.is_center_position_z_constant

    @property
    def is_simple_parallel_tomo(self):
        """True for unrotated parallel-beam geometry with a constant rotation center."""
        return (not (self.is_axis_rotated or self.is_detector_rotated or
                     self.is_volume_rotated) and self.is_parallel and
                self.is_center_constant)

    def optimize_args(self, region=None):
        """Restrict the projections vertically to the rows needed for *region*.

        Sets ``args.y`` and ``args.height`` to the computed row range and shifts
        ``args.center_position_z`` by the new first row. *region* defaults to ``args.region``.
        """
        xmin, ymin, xmax, ymax = self.compute_height(region=region)
        center_position_z = np.array(self.args.center_position_z) - ymin
        self.args.center_position_z = center_position_z.tolist()
        self.args.y = int(ymin)
        self.args.height = int(ymax - ymin)
        LOG.debug('Optimization for region: %s', region or self.args.region)
        LOG.debug('Optimized X: %d - %d, Z: %d - %d', xmin, xmax, ymin, ymax)
        LOG.debug('Optimized Z: %d', self.args.y)
        LOG.debug('Optimized height: %d', self.args.height)
        LOG.debug('Optimized center_position_z: %g - %g', self.args.center_position_z[0],
                  self.args.center_position_z[-1])

    def compute_height(self, region=None):
        """Return the detector bounding box ``(xmin, ymin, xmax, ymax)`` needed for *region*.

        Both ends of *region* (``region[0]`` and ``region[1]``, default ``args.region``) are
        projected at a set of projection indices and the union of the resulting detector
        regions is returned. The vertical extent is forced to be at least one row.
        """
        extrema = []
        if not region:
            region = self.args.region
        if self.is_simple_parallel_tomo:
            # Simple parallel beam tomography: only evaluate the projections at odd multiples
            # of 45 degrees (45 and 135, plus 225 and 315 when the scan covers more than 180
            # degrees), rounded to the nearest projection index
            LOG.debug('Computing optimal projection region from 4 angles')
            projs_per_45 = self.args.number / self.args.overall_angle * np.pi / 4
            stop = 4 if self.args.overall_angle <= np.pi else 8
            indices = projs_per_45 * np.arange(1, stop, 2)
            # Builtin int: the ``np.int`` alias was removed in NumPy 1.24
            indices = np.round(indices).astype(int).tolist()
        else:
            LOG.debug('Computing optimal projection region from all angles')
            indices = list(range(self.args.number))
        for i in indices:
            extrema_0 = self._compute_one_parameter(region[0], i)
            extrema_1 = self._compute_one_parameter(region[1], i)
            extrema.append(extrema_0)
            extrema.append(extrema_1)
        # Rows of extrema are (x_min, x_max, y_min, y_max)
        minima = np.min(extrema, axis=0)
        maxima = np.max(extrema, axis=0)
        if maxima[-1] == minima[2]:
            # Don't let height be 0
            maxima[-1] += 1
        result = tuple(minima[::2]) + tuple(maxima[1::2])
        return result

    def _compute_one_parameter(self, param_value, index):
        """Return the detector region ``(x_min, x_max, y_min, y_max)`` for one projection.

        Builds the geometry of projection *index*, replaces the quantity selected by
        ``args.z_parameter`` with *param_value* and projects the corners of the x/y
        reconstruction region (at z, and additionally at z + 1 when the z parameter is not
        ``z`` itself).

        :raises RuntimeError: for an unknown ``args.z_parameter``
        """
        source_position = np.array([get_scarray_value(self.args.source_position_x, index),
                                    get_scarray_value(self.args.source_position_y, index),
                                    get_scarray_value(self.args.source_position_z, index)])
        axis = Vector(x_angle=get_scarray_value(self.args.axis_angle_x, index),
                      y_angle=get_scarray_value(self.args.axis_angle_y, index),
                      z_angle=get_scarray_value(self.args.axis_angle_z, index),
                      position=[get_scarray_value(self.args.center_position_x, index),
                                0,
                                get_scarray_value(self.args.center_position_z, index)])
        detector = Vector(x_angle=get_scarray_value(self.args.detector_angle_x, index),
                          y_angle=get_scarray_value(self.args.detector_angle_y, index),
                          z_angle=get_scarray_value(self.args.detector_angle_z, index),
                          position=[get_scarray_value(self.args.detector_position_x, index),
                                    get_scarray_value(self.args.detector_position_y, index),
                                    get_scarray_value(self.args.detector_position_z, index)])
        volume_angle = Vector(x_angle=get_scarray_value(self.args.volume_angle_x, index),
                              y_angle=get_scarray_value(self.args.volume_angle_y, index),
                              z_angle=get_scarray_value(self.args.volume_angle_z, index))
        z = self.args.z
        if self.args.z_parameter == 'z':
            z = param_value
        elif self.args.z_parameter == 'axis-angle-x':
            axis.x_angle = param_value
        elif self.args.z_parameter == 'axis-angle-y':
            axis.y_angle = param_value
        elif self.args.z_parameter == 'axis-angle-z':
            axis.z_angle = param_value
        elif self.args.z_parameter == 'volume-angle-x':
            volume_angle.x_angle = param_value
        elif self.args.z_parameter == 'volume-angle-y':
            volume_angle.y_angle = param_value
        elif self.args.z_parameter == 'volume-angle-z':
            volume_angle.z_angle = param_value
        elif self.args.z_parameter == 'detector-angle-x':
            detector.x_angle = param_value
        elif self.args.z_parameter == 'detector-angle-y':
            detector.y_angle = param_value
        elif self.args.z_parameter == 'detector-angle-z':
            detector.z_angle = param_value
        elif self.args.z_parameter == 'detector-position-x':
            detector.position[0] = param_value
        elif self.args.z_parameter == 'detector-position-y':
            detector.position[1] = param_value
        elif self.args.z_parameter == 'detector-position-z':
            detector.position[2] = param_value
        elif self.args.z_parameter == 'source-position-x':
            source_position[0] = param_value
        elif self.args.z_parameter == 'source-position-y':
            source_position[1] = param_value
        elif self.args.z_parameter == 'source-position-z':
            source_position[2] = param_value
        elif self.args.z_parameter == 'center-position-x':
            axis.position[0] = param_value
        elif self.args.z_parameter == 'center-position-z':
            axis.position[2] = param_value
        else:
            raise RuntimeError("Unknown z parameter '{}'".format(self.args.z_parameter))
        points = get_extrema(self.args.x_region, self.args.y_region, z)
        if self.args.z_parameter != 'z':
            points_upper = get_extrema(self.args.x_region, self.args.y_region, z + 1)
            points = np.hstack((points, points_upper))
        tomo_angle = float(index) / self.args.number * self.args.overall_angle
        xe, ye = compute_detector_pixels(points, source_position, axis, volume_angle, detector,
                                         tomo_angle)
        return compute_detector_region(xe, ye, (self.args.height, self.args.width),
                                       overhead=self.args.projection_margin)
def project(points, source, detector_normal, detector_offset):
    """Project *points* onto the detector plane.

    :param points: 3 x N array of point coordinates (rows x, y, z); it is not modified
    :param source: source position 3-vector; an infinite y component means parallel beam
    :param detector_normal: normal (a, b, c) of the detector plane
    :param detector_offset: *d* of the plane equation ax + by + cz + d = 0
    :returns: new 3 x N float array with the projected points
    """
    points = np.asarray(points)
    if np.isinf(source[1]):
        # Parallel beam. Work on a float copy so the caller's array is neither mutated nor
        # truncated when it holds integers.
        projected = np.array(points, dtype=float)
        if np.any(detector_normal != np.array([0., -1, 0])):
            # Detector is not perpendicular, compute translation along the beam direction,
            # otherwise don't compute anything because voxels are mapped directly
            # to detector coordinates
            projected[1, :] = - (detector_offset +
                                 detector_normal[0] * projected[0, :] +
                                 detector_normal[2] * projected[2, :]) / detector_normal[1]
        return projected

    # Cone beam: intersect the rays from the source through the points with the plane. The
    # denominator is only computed here, where the source position is finite.
    source_extended = np.tile(source[:, np.newaxis], [1, points.shape[1]])
    detector_normal_extended = np.tile(detector_normal[:, np.newaxis], [1, points.shape[1]])
    denom = np.sum((points - source_extended) * detector_normal_extended, axis=0)
    u = -(detector_offset + np.dot(source, detector_normal)) / denom
    u = np.tile(u, [3, 1])
    return source_extended + (points - source_extended) * u
def compute_detector_pixels(points, source_position, axis, volume_rotation, detector, tomo_angle):
    """Map the volume *points* to detector pixel coordinates for one projection.

    *points* is a 3 x N array (rows are the x, y and z coordinates of points along the
    x-direction). *source_position* is a 3-vector; *axis*, *volume_rotation* and *detector*
    are util.Vector instances whose angles are passed to the ``rotate_*`` helpers (radians).
    *tomo_angle* is the rotation angle of this projection in radians.

    :returns: ``(x, y)`` -- horizontal and vertical detector coordinates of the points
    """
    # Rotate the detector normal, which points along -y for an unrotated detector. The
    # builtin float replaces the ``np.float`` alias removed in NumPy 1.24.
    detector_normal = np.array((0, -1, 0), dtype=float)
    detector_normal = rotate_z(detector.z_angle, detector_normal)
    detector_normal = rotate_y(detector.y_angle, detector_normal)
    detector_normal = rotate_x(detector.x_angle, detector_normal)
    # Compute d from ax + by + cz + d = 0 for the plane through detector.position
    detector_offset = -np.dot(detector.position, detector_normal)

    if np.isinf(source_position[1]):
        # Parallel beam
        voxels = points
    else:
        # Apply magnification
        voxels = -points * source_position[1] / (detector.position[1] - source_position[1])

    # Rotate the volume
    voxels = rotate_z(volume_rotation.z_angle, voxels)
    voxels = rotate_y(volume_rotation.y_angle, voxels)
    voxels = rotate_x(volume_rotation.x_angle, voxels)
    # Rotate around the axis by the tomographic angle
    voxels = rotate_z(tomo_angle, voxels)
    # Tilt according to the rotation axis angles
    voxels = rotate_z(axis.z_angle, voxels)
    voxels = rotate_y(axis.y_angle, voxels)
    voxels = rotate_x(axis.x_angle, voxels)

    # Get the projected pixel
    projected = project(voxels, source_position, detector_normal, detector_offset)
    if np.any(detector_normal != np.array([0., -1, 0])):
        # Detector is not perpendicular: move to detector coordinates
        projected -= np.array([detector.position]).T
        # Reverse rotation => reverse order of transformation matrices and negative angles
        projected = rotate_x(-detector.x_angle, projected)
        projected = rotate_y(-detector.y_angle, projected)
        projected = rotate_z(-detector.z_angle, projected)

    x = projected[0, :] + axis.position[0] - 0.5
    y = projected[2, :] + axis.position[2] - 0.5
    return x, y
def compute_detector_region(x, y, shape, overhead=2):
    """Return the detector region ``(x_min, x_max, y_min, y_max)`` covering points *x*, *y*.

    The extrema are rounded outwards (floor for minima, ceil for maxima), widened by
    *overhead* pixels of margin on every side and clipped to the detector, whose *shape* is
    ``(height, width)``.
    """
    height, width = shape[0], shape[1]
    lower_x = int(np.floor(min(x)) - overhead)
    lower_y = int(np.floor(min(y)) - overhead)
    upper_x = int(np.ceil(max(x)) + overhead)
    upper_y = int(np.ceil(max(y)) + overhead)
    return (min(width, max(0, lower_x)),
            max(0, min(width, upper_x)),
            min(height, max(0, lower_y)),
            max(0, min(height, upper_y)))
def get_extrema(x_region, y_region, z):
    """Return the corners of the x/y region at height *z* as a 3 x 4 float array.

    Only the first two entries (start, stop) of *x_region* and *y_region* are used. Row 0
    holds x, row 1 y and row 2 z; the columns are ordered like ``itertools.product``.
    """
    def _bounds(region):
        # Start and stop of a (start, stop[, step]) region
        return (region[0], region[1])

    corners = itertools.product(_bounds(x_region), _bounds(y_region), [z])
    # Builtin float: the ``np.float`` alias was removed in NumPy 1.24
    return np.array(list(corners), dtype=float).T.copy()
def rotate_x(angle, point):
    """Rotate *point* (a 3-vector or 3 x N array) by *angle* radians about the x axis."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[1., 0., 0.],
                         [0., c, -s],
                         [0., s, c]])
    return rotation.dot(point)
def rotate_y(angle, point):
    """Rotate *point* (a 3-vector or 3 x N array) by *angle* radians about the y axis."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, 0., s],
                         [0., 1., 0.],
                         [-s, 0., c]])
    return rotation.dot(point)
def rotate_z(angle, point):
    """Rotate *point* (a 3-vector or 3 x N array) by *angle* radians about the z axis."""
    c, s = np.cos(angle), np.sin(angle)
    rotation = np.array([[c, -s, 0.],
                         [s, c, 0.],
                         [0., 0., 1.]])
    return rotation.dot(point)
tofu-0.12.0/tofu/gui.py 0000664 0000000 0000000 00000054571 14237137211 0014677 0 ustar 00root root 0000000 0000000 import sys
import os
import logging
import numpy as np
import tifffile
import pkg_resources
from argparse import ArgumentParser
from contextlib import contextmanager
from . import reco, config, util, __version__
try:
import tofu.vis.qt
from PyQt4 import QtGui, QtCore, uic
except ImportError:
raise ImportError("Cannot import modules for GUI, please install PyQt4 and pyqtgraph")
LOG = logging.getLogger(__name__)
def set_last_dir(path, line_edit, last_dir):
    """Show *path* in *line_edit* if it exists and return it as the new last directory.

    If *path* does not exist, *line_edit* is left untouched and *last_dir* is returned.
    """
    if not os.path.exists(str(path)):
        return last_dir
    line_edit.clear()
    line_edit.setText(path)
    return str(line_edit.text())
def get_filtered_filenames(path, exts=('.tif', '.edf')):
    """Return the sorted full paths of the entries of directory *path* ending with *exts*.

    :param path: directory to list
    :param exts: iterable of case-sensitive filename suffixes (an immutable tuple by default)
    :returns: sorted list of ``os.path.join(path, name)``; empty if *path* cannot be listed
    """
    try:
        names = os.listdir(path)
    except OSError:
        return []
    # List the directory once and match all suffixes in a single endswith() call
    suffixes = tuple(exts)
    return sorted(os.path.join(path, name) for name in names if name.endswith(suffixes))
@contextmanager
def spinning_cursor():
    """Show the wait cursor while the ``with`` block runs.

    The override cursor is restored in a ``finally`` clause, so an exception raised inside
    the block does not leave the application stuck with the wait cursor.
    """
    QtGui.QApplication.setOverrideCursor(QtGui.QCursor(QtCore.Qt.WaitCursor))
    try:
        yield
    finally:
        QtGui.QApplication.restoreOverrideCursor()
class CallableHandler(logging.Handler):
    """Logging handler which passes every formatted record to a callable."""

    def __init__(self, func):
        super(CallableHandler, self).__init__()
        # Receives the formatted message string of each emitted record
        self.func = func

    def emit(self, record):
        message = self.format(record)
        self.func(message)
class ApplicationWindow(QtGui.QMainWindow):
def __init__(self, app, params):
    """Load the Qt Designer UI, fill it from *params* and wire up all widget signals.

    :param app: the QApplication, kept for processing events
    :param params: parsed parameter namespace edited by the GUI
    """
    QtGui.QMainWindow.__init__(self)
    self.params = params
    self.app = app
    ui_file = pkg_resources.resource_filename(__name__, 'gui.ui')
    # loadUi() with a base instance builds the widgets on this window
    self.ui = uic.loadUi(ui_file, self)
    self.ui.show()
    self.ui.setAttribute(QtCore.Qt.WA_DeleteOnClose)
    self.ui.tab_widget.setCurrentIndex(0)
    # Viewers are hidden until there is something to show
    self.ui.slice_dock.setVisible(False)
    self.ui.volume_dock.setVisible(False)
    self.ui.axis_view_widget.setVisible(False)
    self.slice_viewer = None
    self.volume_viewer = None
    self.overlap_viewer = tofu.vis.qt.OverlapViewer()
    # Populate the widgets before any signal is connected, so filling them in does not
    # trigger the handlers below
    self.get_values_from_params()
    # Replace all root logger handlers: every record is formatted and appended to the log
    # view through on_log_record
    log_handler = CallableHandler(self.on_log_record)
    log_handler.setLevel(logging.DEBUG)
    log_handler.setFormatter(logging.Formatter('%(name)s: %(message)s'))
    root_logger = logging.getLogger('')
    root_logger.setLevel(logging.DEBUG)
    root_logger.handlers = [log_handler]
    # Tooltips, mostly taken from the help texts of the config sections
    self.ui.input_path_button.setToolTip('Path to projections or sinograms')
    self.ui.proj_button.setToolTip('Denote if path contains projections')
    self.ui.y_step.setToolTip(self.get_help('reading', 'y-step'))
    self.ui.method_box.setToolTip(self.get_help('tomographic-reconstruction', 'method'))
    self.ui.axis_spin.setToolTip(self.get_help('tomographic-reconstruction', 'axis'))
    self.ui.angle_step.setToolTip(self.get_help('reconstruction', 'angle'))
    self.ui.angle_offset.setToolTip(self.get_help('tomographic-reconstruction', 'offset'))
    self.ui.oversampling.setToolTip(self.get_help('dfi', 'oversampling'))
    self.ui.iterations_sart.setToolTip(self.get_help('ir', 'num-iterations'))
    self.ui.relaxation.setToolTip(self.get_help('sart', 'relaxation-factor'))
    self.ui.output_path_button.setToolTip(self.get_help('general', 'output'))
    self.ui.ffc_box.setToolTip(self.get_help('gui', 'ffc-correction'))
    self.ui.interpolate_button.setToolTip('Interpolate between two sets of flat fields')
    self.ui.darks_path_button.setToolTip(self.get_help('flat-correction', 'darks'))
    self.ui.flats_path_button.setToolTip(self.get_help('flat-correction', 'flats'))
    self.ui.flats2_path_button.setToolTip(self.get_help('flat-correction', 'flats2'))
    self.ui.path_button_0.setToolTip(self.get_help('gui', 'deg0'))
    self.ui.path_button_180.setToolTip(self.get_help('gui', 'deg180'))
    # Signals with dedicated handlers
    self.ui.input_path_button.clicked.connect(self.on_input_path_clicked)
    self.ui.sino_button.clicked.connect(self.on_sino_button_clicked)
    self.ui.proj_button.clicked.connect(self.on_proj_button_clicked)
    self.ui.region_box.clicked.connect(self.on_region_box_clicked)
    self.ui.method_box.currentIndexChanged.connect(self.change_method)
    self.ui.axis_spin.valueChanged.connect(self.change_axis_spin)
    self.ui.angle_step.valueChanged.connect(self.change_angle_step)
    self.ui.output_path_button.clicked.connect(self.on_output_path_clicked)
    self.ui.ffc_box.clicked.connect(self.on_ffc_box_clicked)
    self.ui.interpolate_button.clicked.connect(self.on_interpolate_button_clicked)
    self.ui.darks_path_button.clicked.connect(self.on_darks_path_clicked)
    self.ui.flats_path_button.clicked.connect(self.on_flats_path_clicked)
    self.ui.flats2_path_button.clicked.connect(self.on_flats2_path_clicked)
    self.ui.ffc_options.currentIndexChanged.connect(self.change_ffc_options)
    self.ui.reco_button.clicked.connect(self.on_reconstruct)
    self.ui.path_button_0.clicked.connect(self.on_path_0_clicked)
    self.ui.path_button_180.clicked.connect(self.on_path_180_clicked)
    self.ui.show_slices_button.clicked.connect(self.on_show_slices_clicked)
    self.ui.show_volume_button.clicked.connect(self.on_show_volume_clicked)
    self.ui.run_button.clicked.connect(self.on_compute_center)
    self.ui.save_action.triggered.connect(self.on_save_as)
    self.ui.clear_action.triggered.connect(self.on_clear)
    self.ui.clear_output_dir_action.triggered.connect(self.on_clear_output_dir_clicked)
    self.ui.open_action.triggered.connect(self.on_open_from)
    self.ui.close_action.triggered.connect(self.close)
    self.ui.about_action.triggered.connect(self.on_about)
    self.ui.extrema_checkbox.clicked.connect(self.on_remove_extrema_clicked)
    self.ui.overlap_opt.currentIndexChanged.connect(self.on_overlap_opt_changed)
    self.ui.input_path_line.textChanged.connect(self.on_input_path_changed)
    # Widgets that map one-to-one onto a parameter store their value via change_value
    self.ui.y_step.valueChanged.connect(lambda value: self.change_value('y_step', value))
    self.ui.angle_offset.valueChanged.connect(lambda value: self.change_value('offset', value))
    self.ui.oversampling.valueChanged.connect(lambda value: self.change_value('oversampling', value))
    self.ui.iterations_sart.valueChanged.connect(lambda value: self.change_value('num_iterations', value))
    self.ui.relaxation.valueChanged.connect(lambda value: self.change_value('relaxation_factor', value))
    self.ui.output_path_line.textChanged.connect(lambda value: self.change_value('output', str(self.ui.output_path_line.text())))
    self.ui.darks_path_line.textChanged.connect(lambda value: self.change_value('darks', str(self.ui.darks_path_line.text())))
    self.ui.flats_path_line.textChanged.connect(lambda value: self.change_value('flats', str(self.ui.flats_path_line.text())))
    self.ui.flats2_path_line.textChanged.connect(lambda value: self.change_value('flats2', str(self.ui.flats2_path_line.text())))
    self.ui.fix_naninf_box.clicked.connect(lambda value: self.change_value('fix_nan_and_inf', self.ui.fix_naninf_box.isChecked()))
    self.ui.absorptivity_box.clicked.connect(lambda value: self.change_value('absorptivity', self.ui.absorptivity_box.isChecked()))
    self.ui.path_line_0.textChanged.connect(lambda value: self.change_value('deg0', str(self.ui.path_line_0.text())))
    self.ui.path_line_180.textChanged.connect(lambda value: self.change_value('deg180', str(self.ui.path_line_180.text())))
    # Embed the overlap viewer used for finding the rotation axis
    self.ui.overlap_layout.addWidget(self.overlap_viewer)
    self.overlap_viewer.slider.valueChanged.connect(self.on_axis_slider_changed)
def on_log_record(self, record):
    """Append an already formatted log message to the log view."""
    browser = self.ui.text_browser
    browser.append(record)
def get_values_from_params(self):
    """Populate all widgets from ``self.params``.

    Unset (falsy) paths are shown as empty lines and unset numbers as the widget's neutral
    value. The input-type, method and region handlers are run so that the dependent widgets
    are shown/enabled accordingly.
    """
    self.ui.input_path_line.setText(self.params.sinograms or self.params.projections or '.')
    self.ui.output_path_line.setText(self.params.output or '')
    self.ui.darks_path_line.setText(self.params.darks or '')
    self.ui.flats_path_line.setText(self.params.flats or '')
    self.ui.flats2_path_line.setText(self.params.flats2 or '')
    # Like the other path lines, show an empty string instead of passing None
    self.ui.path_line_0.setText(self.params.deg0 or '')
    self.ui.path_line_180.setText(self.params.deg180 or '')
    self.ui.y_step.setValue(self.params.y_step if self.params.y_step else 1)
    self.ui.axis_spin.setValue(self.params.axis if self.params.axis else 0.0)
    self.ui.angle_step.setValue(self.params.angle if self.params.angle else 0.0)
    self.ui.angle_offset.setValue(self.params.offset if self.params.offset else 0.0)
    self.ui.oversampling.setValue(self.params.oversampling if self.params.oversampling else 0)
    self.ui.iterations_sart.setValue(self.params.num_iterations if
                                     self.params.num_iterations else 0)
    self.ui.relaxation.setValue(self.params.relaxation_factor if
                                self.params.relaxation_factor else 0.0)

    # Input type and the widgets depending on it
    if self.params.projections is not None:
        self.ui.proj_button.setChecked(True)
        self.ui.sino_button.setChecked(False)
        self.on_proj_button_clicked()
    else:
        self.ui.proj_button.setChecked(False)
        self.ui.sino_button.setChecked(True)
        self.on_sino_button_clicked()

    if self.params.method == "fbp":
        self.ui.method_box.setCurrentIndex(0)
    elif self.params.method == "dfi":
        self.ui.method_box.setCurrentIndex(1)
    elif self.params.method == "sart":
        self.ui.method_box.setCurrentIndex(2)
    self.change_method()

    # Row stepping is only offered for sinograms. y_step may be None when it was not
    # specified; comparing None with an int raises TypeError on Python 3.
    if (self.params.y_step or 1) > 1 and self.ui.sino_button.isChecked():
        self.ui.region_box.setChecked(True)
    else:
        self.ui.region_box.setChecked(False)
    self.on_region_box_clicked()

    # Flat-field correction needs flats and darks and only applies to projections
    ffc_enabled = (bool(self.params.flats) and bool(self.params.darks) and
                   self.ui.proj_button.isChecked())
    self.ui.ffc_box.setChecked(ffc_enabled)
    self.ui.preprocessing_container.setVisible(ffc_enabled)
    self.ui.interpolate_button.setChecked(bool(self.params.flats2) and ffc_enabled)
    self.ui.fix_naninf_box.setChecked(self.params.fix_nan_and_inf)
    self.ui.absorptivity_box.setChecked(self.params.absorptivity)
    if self.params.reduction_mode.lower() == "average":
        self.ui.ffc_options.setCurrentIndex(0)
    else:
        self.ui.ffc_options.setCurrentIndex(1)
def change_method(self):
    """Store the selected reconstruction method and show only its option widgets."""
    self.params.method = str(self.ui.method_box.currentText()).lower()
    method = self.params.method
    dfi_widgets = (self.ui.oversampling_label, self.ui.oversampling)
    sart_widgets = (self.ui.relaxation, self.ui.relaxation_label,
                    self.ui.iterations_sart, self.ui.iterations_sart_label)
    for widget in dfi_widgets:
        widget.setVisible(method == 'dfi')
    for widget in sart_widgets:
        widget.setVisible(method == 'sart')
def get_help(self, section, name):
    """Return the help text of option *name* in config section *section*."""
    # Return directly instead of binding a local named ``help``, which shadowed the builtin
    return config.SECTIONS[section][name]['help']
def change_value(self, name, value):
    """Set the parameter called *name* to *value*; used by the simple widget signals."""
    params = self.params
    setattr(params, name, value)
def on_sino_button_clicked(self):
    """Switch the input type to sinograms.

    Flat-field correction only applies to projections, so its box is disabled and its
    options hidden. The region box, which on_proj_button_clicked disables, is enabled again.
    """
    self.on_input_path_changed()
    self.ui.ffc_box.setEnabled(False)
    self.ui.preprocessing_container.setVisible(False)
    # Without this the region box stayed disabled after switching projections -> sinograms
    self.ui.region_box.setEnabled(True)
def on_proj_button_clicked(self):
    """Switch the input type to projections.

    Enables the flat-field box (showing its options if it is checked) and disables and
    unchecks the region box, which resets ``params.y_step`` to 1.
    """
    self.on_input_path_changed()
    ffc_box = self.ui.ffc_box
    ffc_box.setEnabled(True)
    self.ui.preprocessing_container.setVisible(ffc_box.isChecked())
    region_box = self.ui.region_box
    region_box.setEnabled(False)
    region_box.setChecked(False)
    self.on_region_box_clicked()
def on_region_box_clicked(self):
    """Enable the y-step spin box only while the region box is checked; sync ``y_step``.

    ``params.y_step`` takes the spin box value when checked and 1 otherwise.
    """
    checked = self.ui.region_box.isChecked()
    self.ui.y_step.setEnabled(checked)
    self.params.y_step = self.ui.y_step.value() if checked else 1
def on_input_path_changed(self):
    """Store the input path as sinograms or projections, depending on the input type."""
    if not self.ui.sino_button.isChecked():
        self.params.sinograms, self.params.projections = (
            None, str(self.ui.input_path_line.text()))
        return
    self.params.sinograms, self.params.projections = (
        str(self.ui.input_path_line.text()), None)
def on_input_path_clicked(self, checked):
    """Let the user pick the input directory and show it in the input path line."""
    start = self.params.projections or self.params.sinograms
    chosen = self.get_path(start, self.params.last_dir)
    self.params.last_dir = set_last_dir(chosen, self.ui.input_path_line, self.params.last_dir)
def change_axis_spin(self):
    """Store the rotation axis position; a spin box value of 0 means "not set" (None)."""
    value = self.ui.axis_spin.value()
    self.params.axis = value if value != 0 else None
def change_angle_step(self):
    """Store the angle step; a spin box value of 0 means "not set" (None)."""
    value = self.ui.angle_step.value()
    self.params.angle = value if value != 0 else None
def on_output_path_clicked(self, checked):
    """Let the user pick the output directory and show it in the output path line."""
    chosen = self.get_path(self.params.output, self.params.last_dir)
    line = self.ui.output_path_line
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def on_clear_output_dir_clicked(self):
    """Delete the files that get_filtered_filenames finds in the shown output directory.

    With the default extensions these are the ``.tif`` and ``.edf`` files. There is no
    confirmation dialog.
    """
    with spinning_cursor():
        output_dir = str(self.ui.output_path_line.text())
        for filename in get_filtered_filenames(output_dir):
            os.remove(filename)
def on_ffc_box_clicked(self):
    """Show the flat-field options and set ``params.ffc_correction`` from the check box."""
    enabled = self.ui.ffc_box.isChecked()
    container = self.ui.preprocessing_container
    container.setVisible(enabled)
    self.params.ffc_correction = enabled
def on_interpolate_button_clicked(self):
    """Enable the second flat-field path widgets only while interpolation is checked."""
    enabled = self.ui.interpolate_button.isChecked()
    for widget in (self.ui.flats2_path_line, self.ui.flats2_path_button):
        widget.setEnabled(enabled)
def change_ffc_options(self):
    """Store the selected flat/dark reduction mode in lower case."""
    selected = str(self.ui.ffc_options.currentText())
    self.params.reduction_mode = selected.lower()
def on_darks_path_clicked(self, checked):
    """Let the user pick the darks directory and show it in the darks path line."""
    chosen = self.get_path(self.params.darks, self.params.last_dir)
    line = self.ui.darks_path_line
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def on_flats_path_clicked(self, checked):
    """Let the user pick the flats directory and show it in the flats path line."""
    chosen = self.get_path(self.params.flats, self.params.last_dir)
    line = self.ui.flats_path_line
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def on_flats2_path_clicked(self, checked):
    """Let the user pick the second flats directory and show it in its path line."""
    chosen = self.get_path(self.params.flats2, self.params.last_dir)
    line = self.ui.flats2_path_line
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def get_path(self, directory, last_dir):
    """Open a directory chooser starting in *last_dir*, or in *directory* if that is empty.

    Note that ``'.'`` is passed as the dialog caption.
    """
    start = last_dir or directory
    return QtGui.QFileDialog.getExistingDirectory(self, '.', start)
def get_filename(self, directory, last_dir):
    """Open a file chooser starting in *last_dir*, or in *directory* if that is empty.

    Note that ``'.'`` is passed as the dialog caption.
    """
    start = last_dir or directory
    return QtGui.QFileDialog.getOpenFileName(self, '.', start)
def on_path_0_clicked(self, checked):
    """Let the user pick the 0 degree projection file and show it in its path line."""
    chosen = self.get_filename(self.params.deg0, self.params.last_dir)
    line = self.ui.path_line_0
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def on_path_180_clicked(self, checked):
    """Let the user pick the 180 degree projection file and show it in its path line."""
    chosen = self.get_filename(self.params.deg180, self.params.last_dir)
    line = self.ui.path_line_180
    self.params.last_dir = set_last_dir(chosen, line, self.params.last_dir)
def on_open_from(self):
    """Load the parameters from a config file chosen by the user and refresh the widgets.

    Nothing happens if the file dialog is cancelled.
    """
    config_file = QtGui.QFileDialog.getOpenFileName(self, 'Open ...', self.params.last_dir)
    if not config_file:
        # Dialog was cancelled: there is no file to load, keep the current parameters
        return
    parser = ArgumentParser()
    params = config.Params(sections=config.TOMO_PARAMS + ('gui',))
    parser = params.add_arguments(parser)
    self.params = parser.parse_known_args(config.config_to_list(config_name=config_file))[0]
    self.get_values_from_params()
def on_about(self):
    """Show the About dialog with the package version."""
    text = "GUI is part of ufo-reconstruct {}.".format(__version__)
    title = "About ufo-reconstruct"
    QtGui.QMessageBox.about(self, title, text)
def on_save_as(self):
    """Ask for a file name and write the current parameters to it as a config file.

    The dialog proposes ``reco.conf`` in the last used directory if it exists, otherwise in
    the user's home directory. Nothing is written if the dialog is cancelled.
    """
    last_dir = self.params.last_dir
    if last_dir and os.path.exists(last_dir):
        config_file = os.path.join(str(last_dir), "reco.conf")
    else:
        # Join with a path separator (plain concatenation produced ".../userreco.conf");
        # expanduser also works when $HOME is not set
        config_file = os.path.join(os.path.expanduser('~'), "reco.conf")
    save_config = QtGui.QFileDialog.getSaveFileName(self, 'Save as ...', config_file)
    if save_config:
        sections = config.TOMO_PARAMS + ('gui',)
        config.write(save_config, args=self.params, sections=sections)
def on_clear(self):
    """Reset the GUI widgets and the related parameters to their defaults.

    Widgets are reset first and ``self.params`` afterwards. Keep this order:
    presumably some of the widget setters emit signals whose handlers also
    write to ``self.params`` (TODO confirm), and the explicit assignments
    below must win.
    """
    self.ui.axis_view_widget.setVisible(False)
    # All path fields point to the current directory again
    self.ui.input_path_line.setText('.')
    self.ui.output_path_line.setText('.')
    self.ui.darks_path_line.setText('.')
    self.ui.flats_path_line.setText('.')
    self.ui.flats2_path_line.setText('.')
    self.ui.path_line_0.setText('.')
    self.ui.path_line_180.setText('.')
    # Check boxes / radio buttons
    self.ui.fix_naninf_box.setChecked(True)
    self.ui.absorptivity_box.setChecked(True)
    self.ui.sino_button.setChecked(True)
    self.ui.proj_button.setChecked(False)
    self.ui.region_box.setChecked(False)
    self.ui.ffc_box.setChecked(False)
    self.ui.interpolate_button.setChecked(False)
    # Numeric inputs and combo boxes
    self.ui.y_step.setValue(1)
    self.ui.axis_spin.setValue(0)
    self.ui.angle_step.setValue(0)
    self.ui.angle_offset.setValue(0)
    self.ui.oversampling.setValue(0)
    self.ui.ffc_options.setCurrentIndex(0)
    self.ui.text_browser.clear()
    self.ui.method_box.setCurrentIndex(0)
    # Parameter defaults; angle and axis become "unset" (on_reconstruct warns
    # about a missing angle)
    self.params.enable_cropping = False
    self.params.reduction_mode = "average"
    self.params.fix_nan_and_inf = True
    self.params.absorptivity = True
    self.params.show_2d = False
    self.params.show_3d = False
    self.params.angle = None
    self.params.axis = None
    # Re-run the click handlers, presumably to update the widgets which depend
    # on these check boxes
    self.on_region_box_clicked()
    self.on_ffc_box_clicked()
    self.on_interpolate_button_clicked()
def on_reconstruct(self):
    """Reconstruct the data selected in the GUI with the current parameters.

    The central widget is disabled while working. It is always re-enabled
    afterwards, also if preparing the parameters or the reconstruction itself
    raises. Problems are reported with ``gui_warn``.

    After an attempt (but not when no input data was found), the angle is
    reset to the value of the angle step widget. This undoes the scaling by
    ``y_step`` applied below.
    """
    with spinning_cursor():
        self.ui.centralWidget.setEnabled(False)
        self.repaint()
        self.app.processEvents()
        try:
            input_path = str(self.ui.input_path_line.text())
            input_images = get_filtered_filenames(input_path)
            if not input_images:
                self.gui_warn("No data found in {}".format(input_path))
                return
            if self.params.angle is None:
                # Must be checked before the angle is scaled by y_step below,
                # which raised a TypeError on None and left the GUI disabled.
                self.gui_warn("Missing argument for Angle step (rad)")
            else:
                try:
                    shape = util.get_image_shape(input_images[0])
                    self.params.width = shape[-1]
                    self.params.height = shape[-2]
                    # Flat-field correction is only applied to projection input
                    self.params.ffc_correction = (self.params.ffc_correction and
                                                  self.ui.proj_button.isChecked())
                    # A directory was given: write a numbered slice series into it
                    if not self.params.output.endswith(('.tif', '.tiff')):
                        self.params.output = os.path.join(self.params.output,
                                                          'slice-%05i.tif')
                    if self.params.y_step > 1:
                        self.params.angle *= self.params.y_step
                    if self.params.ffc_correction:
                        flats_files = get_filtered_filenames(str(self.ui.flats_path_line.text()))
                        self.params.num_flats = len(flats_files)
                    else:
                        self.params.num_flats = 0
                        self.params.darks = None
                        self.params.flats = None
                    self.params.flats2 = (self.ui.flats2_path_line.text()
                                          if self.ui.interpolate_button.isChecked() else '')
                    self.params.oversampling = (self.ui.oversampling.value()
                                                if self.params.method == 'dfi' else None)
                    if self.params.method == 'sart':
                        self.params.max_iterations = self.ui.iterations_sart.value()
                        self.params.relaxation_factor = self.ui.relaxation.value()
                    reco.tomo(self.params)
                except Exception as e:
                    self.gui_warn(str(e))
            self.params.angle = self.ui.angle_step.value()
        finally:
            self.ui.centralWidget.setEnabled(True)
def on_show_slices_clicked(self):
    """Show the slices found in the output directory in the slice viewer.

    An existing viewer just loads the files. On first use an ImageViewer is
    created from the files, put into the slice dock and the dock is shown.
    """
    filenames = get_filtered_filenames(str(self.ui.output_path_line.text()))
    if self.slice_viewer:
        self.slice_viewer.load_files(filenames)
        return
    self.slice_viewer = tofu.vis.qt.ImageViewer(filenames)
    self.slice_dock.setWidget(self.slice_viewer)
    self.ui.slice_dock.setVisible(True)
def on_show_volume_clicked(self):
    """Show the slices found in the output directory in the volume viewer.

    On first use a VolumeViewer is created with the step selected in the
    reduction combo box, put into the volume dock and the dock is shown.
    The files are loaded in every case.
    """
    if not self.volume_viewer:
        reduction = int(self.ui.reduction_box.currentText())
        viewer = tofu.vis.qt.VolumeViewer(parent=self, step=reduction)
        self.volume_viewer = viewer
        self.volume_dock.setWidget(viewer)
        self.ui.volume_dock.setVisible(True)
    output_dir = str(self.ui.output_path_line.text())
    self.volume_viewer.load_files(get_filtered_filenames(output_dir))
def on_compute_center(self):
    """Estimate the rotation axis from the 0 and 180 degree projections.

    The first page of the 0 degree file and the last page of the 180 degree
    file are used. If flat-field correction is enabled, both images are first
    corrected with the mean of the darks and flats found in the respective
    directories.

    The axis is stored in ``self.axis``, the image size in ``self.height`` /
    ``self.width``. The overlap viewer gets both images and the matching
    slider position, and the image size is shown in the GUI.
    """
    def read_all(path):
        # The context manager closes the file; previously the handles leaked
        with tifffile.TiffFile(path) as tif:
            return tif.asarray().astype(np.float64)

    first_name = str(self.ui.path_line_0.text())
    second_name = str(self.ui.path_line_180.text())
    # np.float (an alias of the builtin float, i.e. float64) was removed in NumPy 1.24
    with tifffile.TiffFile(first_name) as tif:
        first = tif.pages[0].asarray().astype(np.float64)
    with tifffile.TiffFile(second_name) as tif:
        second = tif.pages[-1].asarray().astype(np.float64)
    if self.params.ffc_correction:
        # FIXME: we should of course use the pipelines we have ...
        flat_files = get_filtered_filenames(str(self.ui.flats_path_line.text()))
        dark_files = get_filtered_filenames(str(self.ui.darks_path_line.text()))
        flats = np.array([read_all(x) for x in flat_files])
        darks = np.array([read_all(x) for x in dark_files])
        dark = np.mean(darks, axis=0)
        flat = np.mean(flats, axis=0) - dark
        first = (first - dark) / flat
        second = (second - dark) / flat
    self.axis = reco.compute_rotation_axis(first, second)
    self.height, self.width = first.shape
    w2 = self.width / 2.0
    # Slider position for this axis; on_axis_slider_changed maps it back
    position = w2 + (w2 - self.axis) * 2.0
    self.overlap_viewer.set_images(first, second)
    self.overlap_viewer.set_position(position)
    self.ui.img_size.setText('width = {} | height = {}'.format(self.width, self.height))
def on_remove_extrema_clicked(self, val):
    """Forward the state *val* of the "Remove extrema" check box to the overlap viewer."""
    viewer = self.ui.overlap_viewer
    viewer.remove_extrema = val
def on_overlap_opt_changed(self, index):
    """Switch the overlap viewer between subtraction and addition and redraw it.

    Combo box entry 0 is "Subtraction overlap"; any other index selects
    addition.
    """
    viewer = self.ui.overlap_viewer
    viewer.subtract = (index == 0)
    viewer.update_image()
def on_axis_slider_changed(self):
    """Recompute the rotation axis from the overlap viewer's slider and display it.

    With ``w2 = self.width / 2`` and slider value ``p`` the axis is
    ``w2 + (w2 - p) / 2``. It is stored in ``self.axis``, shown as text and
    copied into the axis spin box.
    """
    half_width = self.width / 2.0
    slider_pos = self.overlap_viewer.slider.value()
    self.axis = half_width + (half_width - slider_pos) / 2
    self.ui.axis_num.setText('{} px'.format(self.axis))
    self.ui.axis_spin.setValue(self.axis)
def gui_warn(self, message):
    """Show *message* in a warning message box titled "Warning"."""
    title = "Warning"
    QtGui.QMessageBox.warning(self, title, message)
def main(params):
    """Start the Qt application with an ApplicationWindow built from *params*.

    Does not return: the interpreter exits with the event loop's exit code.
    """
    app = QtGui.QApplication(sys.argv)
    # Keep a reference to the window while the event loop runs, so the Python
    # wrapper (and with it the widget) cannot be garbage collected.
    window = ApplicationWindow(app, params)  # noqa: F841
    sys.exit(app.exec_())
tofu-0.12.0/tofu/gui.ui 0000664 0000000 0000000 00000137776 14237137211 0014675 0 ustar 00root root 0000000 0000000
mainWindow
0
0
1018
1081
0
0
Tomoviewer
true
0
0
541
761
-
0
0
530
0
PreferDefault
true
Qt::TabFocus
Qt::LeftToRight
1
true
0
0
0
0
Reconstruction
-
Input
-
-
0
0
Projections
true
-
0
0
Sinograms
true
true
-
false
0
0
1
500
-
0
0
false
false
Region (y-step):
-
Qt::Horizontal
40
20
-
0
0
Do flat-field correction
-
-
0
0
Path:
-
0
0
-
0
0
Browse …
-
0
0
Flat-field correction
-
-
0
0
-
Average
-
Median
-
Options
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
-
Use absorptivity
-
Remove NaN and Inf
-
0
0
Interpolate
-
Qt::Horizontal
40
20
-
Method:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
true
Darks:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
-
true
-
true
Browse …
-
true
0
Flats:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
Last flats:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
-
true
-
true
Browse …
-
-
-
Browse …
-
Reconstruction
-
6
-
0
0
Angle step (rad):
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
50
false
Method:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
50
false
-
FBP
-
DFI
-
SART
-
0
0
10
-
Qt::Horizontal
40
20
-
0
0
Angle offset (rad):
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
Reconstruct
-
0
0
Axis (pixel):
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
8192.000000000000000
-
0
0
10
-
-
0
0
Max iterations:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
-
0
0
Relaxation factor:
-
0.000000000000000
-
Qt::Horizontal
40
20
-
true
99
0
-
true
0
0
Oversampling:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
Output
-
-
0
0
Path:
-
0
0
-
0
0
Browse ...
-
-
Reduction:
-
1
-
1
-
2
-
4
-
8
-
Qt::Horizontal
40
20
-
Show Volume
-
Show Slices
-
0
0
Log
-
0
0
QFrame::Sunken
0
0
0
Center of rotation
-
-
0
0
Input
-
-
Options:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
Method:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
Browse ...
-
0
0
180° projection:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
Browse ...
-
0
0
In case of multi-page input, last image in the file is used
-
Remove extrema
-
0
0
In case of multi-page input, first image in the file is used
-
0
0
0° projection:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
-
Subtraction overlap
-
Addition overlap
-
0
0
Run
-
Output
-
-
0
0
Center:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
Size:
Qt::AlignRight|Qt::AlignTrailing|Qt::AlignVCenter
-
0
0
-
0
0
75
true
-
-
0
0
0
0
0
0
0
QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable
2
0
0
QDockWidget::DockWidgetFloatable|QDockWidget::DockWidgetMovable
2
Open
Save as...
Open ...
Qt::WindowShortcut
Save as ...
Quit
Ctrl+Q
Clear
Remove old slices
Remove old slices in output directory
About
tofu-0.12.0/tofu/lamino.py 0000664 0000000 0000000 00000021245 14237137211 0015362 0 ustar 00root root 0000000 0000000 """Laminographic reconstruction."""
import logging
import numpy as np
from multiprocessing import Queue, Process
from tofu.preprocess import create_preprocessing_pipeline
from tofu.util import (get_filtering_padding, determine_shape, get_filenames,
get_reconstruction_regions, get_reconstructed_cube_shape)
from tofu.tasks import get_task, get_writer
LOG = logging.getLogger(__name__)
def lamino(params):
    """Laminographic reconstruction utilizing all GPUs.

    The angular arguments in *params* are completed first. Then the slices are
    split into z regions which are reconstructed in passes, each pass using up
    to one region per GPU and running in its own process.

    Raises RuntimeError if no GPU is available, if the region computation
    process dies without a result, or if a reconstruction pass fails.
    """
    LOG.info('Z parameter: {}'.format(params.z_parameter))
    prepare_angular_arguments(params)
    params.projection_filter_scale = np.sin(np.deg2rad(params.lamino_angle))

    # For now we need to make a workaround for the memory leak, which means we need to execute
    # the passes in separate processes to clean up the low level code. For that we also need to
    # call the region-splitting in a separate function.
    # TODO: Simplify after the memory leak fix!
    queue = Queue()
    proc = Process(target=_create_runs, args=(params, queue,))
    proc.start()
    # Read the result before joining: a process which put data on a queue does not terminate
    # before the data is consumed, so joining first can deadlock for long region lists.
    x_region, y_region, regions, num_gpus = _get_child_result(queue, proc)
    proc.join()
    if not num_gpus:
        raise RuntimeError('No GPUs available for laminographic reconstruction')

    for i in range(0, len(regions), num_gpus):
        z_subregion = regions[i:i + num_gpus]
        pass_index = i // num_gpus
        LOG.info('Computing slices {}..{}'.format(z_subregion[0][0], z_subregion[-1][1]))
        proc = Process(target=_run, args=(params, x_region, y_region, z_subregion, pass_index))
        proc.start()
        proc.join()
        if proc.exitcode != 0:
            raise RuntimeError('Reconstruction pass {} failed with exit code {}'.format(
                pass_index, proc.exitcode))


def _get_child_result(queue, proc, poll_interval=1.):
    """Return the next item the child process *proc* puts on *queue*.

    Raises RuntimeError if *proc* terminates without putting anything on *queue*, instead of
    blocking forever.
    """
    import queue as queue_module

    while True:
        try:
            return queue.get(timeout=poll_interval)
        except queue_module.Empty:
            if not proc.is_alive():
                break
    # The child may have put its result right before exiting; the data is then in the pipe
    try:
        return queue.get(timeout=poll_interval)
    except queue_module.Empty:
        raise RuntimeError('Process {} exited with code {} without producing a result'.format(
            proc.name, proc.exitcode))
def prepare_angular_arguments(params):
    """Complete the angular parameters in *params* in place.

    * ``overall_angle`` defaults to 360 degrees.
    * ``angle`` (the angle step) defaults to ``overall_angle / N * step``.
      N is ``params.number`` for dry runs and the number of projection files
      otherwise.
    * ``determine_shape(..., store=True)`` is called, presumably storing the
      input width/height in *params* (TODO confirm).
    * ``number`` defaults to ``round(|overall_angle / angle|)``.

    Raises ValueError for a dry run without ``--number`` and RuntimeError if no
    projection files are found.
    """
    if not params.overall_angle:
        params.overall_angle = 360.
        LOG.info('Overall angle not specified, using 360 deg')

    if not params.angle:
        if not params.dry_run:
            num_files = len(get_filenames(params.projections))
            if not num_files:
                raise RuntimeError("No files found in `{}'".format(params.projections))
        elif params.number:
            num_files = params.number
        else:
            raise ValueError('--number must be specified by --dry-run')
        params.angle = params.overall_angle / num_files * params.step
        LOG.info('Angle not specified, calculating from '
                 '{} projections and step {}: {} deg'.format(num_files, params.step,
                                                             params.angle))

    determine_shape(params, params.projections, store=True)

    if not params.number:
        params.number = int(np.round(np.abs(params.overall_angle / params.angle)))

    if params.dry_run:
        LOG.info('Dummy data W x H x N: {} x {} x {}'.format(params.width, params.height,
                                                             params.number))
def _create_runs(params, queue):
    """Compute the GPU count and the reconstruction regions and send them through *queue*.

    Workaround function: gi.repository must always be called in a separate process, otherwise
    the resources return None gpus. Puts ``(x_region, y_region, regions, num_gpus)`` on *queue*.
    """
    # TODO: remove the whole function after memory leak fix!
    from gi.repository import Ufo

    scheduler = Ufo.FixedScheduler()
    gpu_nodes = scheduler.get_resources().get_gpu_nodes()
    gpu_count = len(gpu_nodes)
    x_region, y_region, regions = _split_regions(params, gpu_nodes)
    pass_count = len(regions)
    LOG.info('Using {} GPUs in {} passes'.format(min(pass_count, gpu_count), pass_count))
    queue.put((x_region, y_region, regions, gpu_count))
def _run(params, x_region, y_region, regions, index):
    """Execute one pass on all possible GPUs with slice ranges given by *regions*.

    The input is read once and copied to one reconstruction branch per region; region ``k``
    runs on GPU ``k``. *index* is the pass number, used to number the output files. Returns
    the scheduler's execution time (also logged).
    """
    from gi.repository import Ufo

    pm = Ufo.PluginManager()
    graph = Ufo.TaskGraph()
    scheduler = Ufo.FixedScheduler()
    gpus = scheduler.get_resources().get_gpu_nodes()
    first_subindex = index * len(gpus)

    broadcast = Ufo.CopyTask()
    graph.connect_nodes(_setup_source(params, pm, graph), broadcast)
    for offset, region in enumerate(regions):
        _setup_graph(pm, graph, first_subindex + offset, x_region, y_region, region,
                     params, broadcast, gpu=gpus[offset])

    scheduler.run(graph)
    duration = scheduler.props.time
    LOG.info('Execution time: {} s'.format(duration))
    return duration
def _setup_source(params, pm, graph):
    """Return the node which provides the projections for the laminographic graph.

    For a dry run this is a ``dummy-data`` task of ``params.number`` images of
    ``params.width`` x ``params.height``. With darks and flats it is a flat-field correction
    pipeline, otherwise a ``read`` task for ``params.projections``.
    """
    from tofu.preprocess import create_flat_correct_pipeline
    from tofu.util import set_node_props, setup_read_task

    if params.dry_run:
        dummy = pm.get_task('dummy-data')
        dummy.props.number = params.number
        dummy.props.width = params.width
        dummy.props.height = params.height
        return dummy

    if params.darks and params.flats:
        return create_flat_correct_pipeline(params, graph)

    reader = pm.get_task('read')
    set_node_props(reader, params)
    setup_read_task(reader, params.projections, params)
    return reader
def _setup_graph(pm, graph, index, x_region, y_region, region, params, source, gpu=None):
    """Add the reconstruction branch for one z *region* to *graph*, all tasks on *gpu*.

    The branch is [preprocessing ->] lamino-backproject -> slice -> writer. With
    ``params.only_bp`` *source* is connected directly to the backprojector. Otherwise a
    preprocessing pipeline reading from *source* is created in front of it. *index* is
    put into the output file names, so that the files of different regions do not collide.
    *pm* is not used here.

    Returns the backprojector if ``params.only_bp`` is set, otherwise the node returned by
    ``create_preprocessing_pipeline`` (the one feeding the backprojector).
    """
    backproject = get_task('lamino-backproject', processing_node=gpu)
    slicer = get_task('slice', processing_node=gpu)
    writer = get_writer(params)
    if not params.dry_run:
        # <output>-<zero-padded index>-<slice number>.tif
        writer.props.filename = '{}-{:>03}-%04i.tif'.format(params.output, index)
    # parameters
    backproject.props.num_projections = params.number
    backproject.props.overall_angle = np.deg2rad(params.overall_angle)
    backproject.props.lamino_angle = np.deg2rad(params.lamino_angle)
    backproject.props.roll_angle = np.deg2rad(params.roll_angle)
    backproject.props.x_region = x_region
    backproject.props.y_region = y_region
    backproject.props.z = params.z
    backproject.props.addressing_mode = params.lamino_padding_mode
    backproject.props.parameter = params.z_parameter
    # If the projections are cropped only after backprojection, they arrive padded and every
    # x coordinate has to be shifted by half of the padding
    if params.projection_crop_after == 'backprojection':
        padding = get_filtering_padding(params.width)
    else:
        padding = 0
    # Angular z parameters are given in degrees, the backprojector gets radians
    if params.z_parameter in ['lamino-angle', 'roll-angle']:
        region = [np.deg2rad(reg) for reg in region]
    if params.z_parameter == 'x-center':
        # Take projection padding into account
        region = [region[0] + padding / 2, region[1] + padding / 2, region[2]]
    backproject.props.region = region
    backproject.props.center = (params.axis[0] + padding / 2, params.axis[1])
    LOG.debug('x center after padding: %g', backproject.props.center[0])
    graph.connect_nodes(backproject, slicer)
    graph.connect_nodes(slicer, writer)
    if params.only_bp:
        first = backproject
        graph.connect_nodes(source, backproject)
    else:
        first = create_preprocessing_pipeline(params, graph, source=source, processing_node=gpu)
        graph.connect_nodes(first, backproject)
    return first
def _split_regions(params, gpus):
    """Split processing between *gpus* by specifying the number of slices processed per GPU.

    Returns ``(x_region, y_region, regions)``. *regions* is a list of ``(start, stop, step)``
    z ranges, each holding at most the number of slices one GPU processes at once. That number
    is ``params.slices_per_device`` if set, otherwise it is estimated from the GPU memory. An
    empty z range gives an empty list.

    Raises RuntimeError if the number of slices per GPU is smaller than 1 (e.g. a single slice
    does not fit into the GPU memory).
    """
    x_region, y_region, z_region = get_reconstruction_regions(params)
    z_start, z_stop, z_step = z_region
    slice_width, slice_height, num_slices = get_reconstructed_cube_shape(x_region, y_region,
                                                                         z_region)
    if params.slices_per_device:
        num_slices_per_gpu = params.slices_per_device
    else:
        num_slices_per_gpu = _compute_num_slices(gpus, slice_width, slice_height)
    if num_slices_per_gpu < 1:
        # np.arange below would fail with a zero step
        raise RuntimeError('Cannot process {} slices of {} x {} pixels per GPU, '
                           'at least 1 is needed'.format(num_slices_per_gpu, slice_width,
                                                         slice_height))
    if not num_slices:
        return x_region, y_region, []
    num_slices_per_gpu = min(num_slices_per_gpu, num_slices)
    LOG.info('Using {} slices per GPU'.format(num_slices_per_gpu))

    chunk = z_step * num_slices_per_gpu
    regions = [(start, min(z_stop, start + chunk), z_step)
               for start in np.arange(z_start, z_stop, chunk)]
    return x_region, y_region, regions
def _compute_num_slices(gpus, width, height):
    """Determine number of slices which can be calculated per-device based on *gpus*, slice *width*
    and *height*.

    The GPU with the least global memory is used for the estimate. One slice is counted as
    ``width * height * 4`` bytes (presumably one float32 per pixel, TODO confirm). The used
    memory is logged. *gpus* must not be empty: ``np.argmin`` raises ValueError on an empty
    sequence.
    """
    from gi.repository import Ufo
    # Make sure the double buffering works with room for intermediate steps
    # TODO: compute this precisely
    safety_coeff = 3.
    # Use the weakest one, if heterogenous systems emerge, measure the performance and
    # reconsider
    memories = [gpu.get_info(Ufo.GpuNodeInfo.GLOBAL_MEM_SIZE) for gpu in gpus]
    i = np.argmin(memories)
    max_allocatable = gpus[i].get_info(Ufo.GpuNodeInfo.MAX_MEM_ALLOC_SIZE)
    if max_allocatable * safety_coeff <= memories[i]:
        # Don't waste resources
        # (safety_coeff buffers of the maximum allocatable size fit into the global memory)
        max_memory = max_allocatable
    else:
        max_memory = memories[i] / safety_coeff
    if max_memory > 2 ** 32:
        # Current NVIDIA implementation allows only 4 GB
        max_memory = 2 ** 32
        max_memory /= safety_coeff
    num_slices = int(np.floor(max_memory / (width * height * 4)))
    LOG.info('GPU memory used per GPU: {:.2f} GB'.format(max_memory / 2. ** 30))
    return num_slices
tofu-0.12.0/tofu/preprocess.py 0000664 0000000 0000000 00000035712 14237137211 0016274 0 ustar 00root root 0000000 0000000 """Flat field correction."""
import sys
import logging
from gi.repository import Ufo
from tofu.util import (get_filenames, set_node_props, make_subargs,
determine_shape, setup_read_task,
setup_padding, next_power_of_two)
from tofu.tasks import get_task, get_writer
LOG = logging.getLogger(__name__)
def create_flat_correct_pipeline(args, graph, processing_node=None):
    """
    Create flat field correction pipeline. All the settings are provided in
    *args*. *graph* is used for making the connections. Returns the flat field
    correction task which can be used for further pipelining.

    The darks and the flats are each reduced to one image, either with an
    ``average`` task or by stacking all files and flattening the stack with
    mode ``median`` (chosen by ``args.reduction_mode``). If ``args.flats2`` is
    set, the flats before and after (``args.flats2``) are fed to an
    ``interpolate`` task. With ``args.resize`` all inputs are first passed
    through ``bin`` tasks of that size.

    Raises RuntimeError if projections, flats or darks are missing and
    ValueError for an unknown reduction mode.
    """
    pm = Ufo.PluginManager()
    if args.projections is None or args.flats is None or args.darks is None:
        raise RuntimeError("You must specify --projections, --flats and --darks.")
    reader = get_task('read')
    dark_reader = get_task('read')
    flat_before_reader = get_task('read')
    ffc = get_task('flat-field-correct', processing_node=processing_node,
                   dark_scale=args.dark_scale,
                   flat_scale=args.flat_scale,
                   absorption_correct=args.absorptivity,
                   fix_nan_and_inf=args.fix_nan_and_inf)
    mode = args.reduction_mode.lower()
    # Darks and flats only get the y, height and y_step settings of the projections
    roi_args = make_subargs(args, ['y', 'height', 'y_step'])
    set_node_props(reader, args)
    set_node_props(dark_reader, roi_args)
    set_node_props(flat_before_reader, roi_args)
    for r, path in ((reader, args.projections), (dark_reader, args.darks), (flat_before_reader, args.flats)):
        setup_read_task(r, path, args)
    LOG.debug("Doing flat field correction using reduction mode `{}'".format(mode))
    if args.flats2:
        flat_after_reader = get_task('read')
        setup_read_task(flat_after_reader, args.flats2, args)
        set_node_props(flat_after_reader, roi_args)
        # The interpolation gets the number of projections which will actually be read:
        # every args.step-th file from args.start on, limited by args.number
        num_files = len(get_filenames(args.projections))
        can_read = len(list(range(args.start, num_files, args.step)))
        number = args.number if args.number else num_files
        num_read = min(can_read, number)
        flat_interpolate = get_task('interpolate', processing_node=processing_node, number=num_read)
    if args.resize:
        LOG.debug("Resize input data by factor of {}".format(args.resize))
        proj_bin = get_task('bin', processing_node=processing_node, size=args.resize)
        dark_bin = get_task('bin', processing_node=processing_node, size=args.resize)
        flat_bin = get_task('bin', processing_node=processing_node, size=args.resize)
        graph.connect_nodes(reader, proj_bin)
        graph.connect_nodes(dark_reader, dark_bin)
        graph.connect_nodes(flat_before_reader, flat_bin)
        # From here on the bin tasks take the place of the readers
        reader, dark_reader, flat_before_reader = proj_bin, dark_bin, flat_bin
        if args.flats2:
            flat_bin = get_task('bin', processing_node=processing_node, size=args.resize)
            graph.connect_nodes(flat_after_reader, flat_bin)
            flat_after_reader = flat_bin
    if mode == 'median':
        # Collect all files of a kind in a stack, then flatten it with mode 'median'
        dark_stack = get_task('stack', processing_node=processing_node,
                              number=len(get_filenames(args.darks)))
        dark_reduced = get_task('flatten', processing_node=processing_node, mode='median')
        flat_before_stack = get_task('stack', processing_node=processing_node,
                                     number=len(get_filenames(args.flats)))
        flat_before_reduced = get_task('flatten', processing_node=processing_node, mode='median')
        graph.connect_nodes(dark_reader, dark_stack)
        graph.connect_nodes(dark_stack, dark_reduced)
        graph.connect_nodes(flat_before_reader, flat_before_stack)
        graph.connect_nodes(flat_before_stack, flat_before_reduced)
        if args.flats2:
            flat_after_stack = get_task('stack', processing_node=processing_node,
                                        number=len(get_filenames(args.flats2)))
            flat_after_reduced = get_task('flatten', processing_node=processing_node,
                                          mode='median')
            graph.connect_nodes(flat_after_reader, flat_after_stack)
            graph.connect_nodes(flat_after_stack, flat_after_reduced)
    elif mode == 'average':
        dark_reduced = get_task('average', processing_node=processing_node)
        flat_before_reduced = get_task('average', processing_node=processing_node)
        graph.connect_nodes(dark_reader, dark_reduced)
        graph.connect_nodes(flat_before_reader, flat_before_reduced)
        if args.flats2:
            flat_after_reduced = get_task('average', processing_node=processing_node)
            graph.connect_nodes(flat_after_reader, flat_after_reduced)
    else:
        raise ValueError('Invalid reduction mode')
    # flat-field-correct inputs: 0 = projections, 1 = dark, 2 = flat
    graph.connect_nodes_full(reader, ffc, 0)
    graph.connect_nodes_full(dark_reduced, ffc, 1)
    if args.flats2:
        graph.connect_nodes_full(flat_before_reduced, flat_interpolate, 0)
        graph.connect_nodes_full(flat_after_reduced, flat_interpolate, 1)
        graph.connect_nodes_full(flat_interpolate, ffc, 2)
    else:
        graph.connect_nodes_full(flat_before_reduced, ffc, 2)
    return ffc
def create_phase_retrieval_pipeline(args, graph, processing_node=None):
    """Create the phase retrieval pipeline pad -> fft -> retrieve-phase -> ifft -> crop -> calculate.

    The ``args.width`` x ``args.height`` images are padded to ``args.retrieval_padded_width`` x
    ``args.retrieval_padded_height``. Unset values default to ``next_power_of_two(size + 64)``
    and are written back to *args*. After the inverse FFT the padding is cropped away. A
    ``calculate`` task then replaces NaN/Inf (for the ``tie`` method also non-positive values)
    by 0 and scales the result.

    Returns the tuple ``(first, last)`` of tasks.
    """
    LOG.debug('Creating phase retrieval pipeline')
    pm = Ufo.PluginManager()
    # Retrieve phase
    phase_retrieve = get_task('retrieve-phase', processing_node=processing_node)
    pad_phase_retrieve = get_task('pad', processing_node=processing_node)
    crop_phase_retrieve = get_task('crop', processing_node=processing_node)
    fft_phase_retrieve = get_task('fft', processing_node=processing_node)
    ifft_phase_retrieve = get_task('ifft', processing_node=processing_node)
    last = crop_phase_retrieve
    width = args.width
    height = args.height
    default_padded_width = next_power_of_two(width + 64)
    default_padded_height = next_power_of_two(height + 64)
    if not args.retrieval_padded_width:
        args.retrieval_padded_width = default_padded_width
    if not args.retrieval_padded_height:
        args.retrieval_padded_height = default_padded_height
    fmt = 'Phase retrieval padding: {}x{} -> {}x{}'
    LOG.debug(fmt.format(width, height, args.retrieval_padded_width,
                         args.retrieval_padded_height))
    # The image is centered in the padded frame; the crop uses the same offsets
    x = (args.retrieval_padded_width - width) // 2
    y = (args.retrieval_padded_height - height) // 2
    pad_phase_retrieve.props.x = x
    pad_phase_retrieve.props.y = y
    pad_phase_retrieve.props.width = args.retrieval_padded_width
    pad_phase_retrieve.props.height = args.retrieval_padded_height
    pad_phase_retrieve.props.addressing_mode = args.retrieval_padding_mode
    crop_phase_retrieve.props.x = x
    crop_phase_retrieve.props.y = y
    crop_phase_retrieve.props.width = width
    crop_phase_retrieve.props.height = height
    phase_retrieve.props.method = args.retrieval_method
    phase_retrieve.props.energy = args.energy
    # A single distance is passed as `distance`, otherwise the first two values are used as
    # `distance_x` and `distance_y`
    if len(args.propagation_distance) == 1:
        phase_retrieve.props.distance = [args.propagation_distance[0]]
    else:
        phase_retrieve.props.distance_x = args.propagation_distance[0]
        phase_retrieve.props.distance_y = args.propagation_distance[1]
    phase_retrieve.props.pixel_size = args.pixel_size
    phase_retrieve.props.regularization_rate = args.regularization_rate
    phase_retrieve.props.thresholding_rate = args.thresholding_rate
    phase_retrieve.props.frequency_cutoff = args.frequency_cutoff
    fft_phase_retrieve.props.dimensions = 2
    ifft_phase_retrieve.props.dimensions = 2
    graph.connect_nodes(pad_phase_retrieve, fft_phase_retrieve)
    graph.connect_nodes(fft_phase_retrieve, phase_retrieve)
    graph.connect_nodes(phase_retrieve, ifft_phase_retrieve)
    graph.connect_nodes(ifft_phase_retrieve, crop_phase_retrieve)
    calculate = get_task('calculate', processing_node=processing_node)
    if args.delta is not None:
        import numpy as np
        # Wavelength h * c / E; the factor 1.60217733e-16 (J per keV) suggests args.energy is
        # in keV (TODO confirm)
        lam = 6.62606896e-34 * 299792458 / (args.energy * 1.60217733e-16)
        thickness_conversion = -lam / (2 * np.pi * args.delta)
    else:
        thickness_conversion = 1
    if args.retrieval_method == 'tie':
        # Formatted twice: the first format fills the log factor and turns `{{}}` into `{}`,
        # the second one fills the final scaling factor
        expression = '(isinf (v) || isnan (v) || (v <= 0)) ? 0.0f : -log ({} * v) * {{}}'
        # 2 for 0.5 factor in ufo-filters and alpha = 10^-R, so divide by 10^R
        expression = expression.format(2 / 10 ** args.regularization_rate)
        # The following converts the TIE result to the actual phase, which when multiplied by the
        # thickness_conversion gives the projected thickness
        thickness_conversion *= -10 ** args.regularization_rate / 2
        expression = expression.format(thickness_conversion)
    else:
        expression = '(isinf (v) || isnan (v)) ? 0.0f : v * {}'.format(thickness_conversion)
    calculate.props.expression = expression
    graph.connect_nodes(crop_phase_retrieve, calculate)
    last = calculate
    return (pad_phase_retrieve, last)
def run_flat_correct(args):
    """Flat-field correct the input described by *args* and write the result.

    Builds a graph of the flat-field correction pipeline followed by the writer from
    ``get_writer(args)`` and runs it.
    """
    graph = Ufo.TaskGraph()
    scheduler = Ufo.Scheduler()
    plugin_manager = Ufo.PluginManager()
    writer = get_writer(args)
    corrected = create_flat_correct_pipeline(args, graph)
    graph.connect_nodes(corrected, writer)
    scheduler.run(graph)
def create_sinogram_pipeline(args, graph):
    """Create sinogram generating pipeline based on arguments from *args*.

    The projections are read (flat-field corrected if both ``args.darks`` and ``args.flats``
    are set) and fed into a ``transpose-projections`` task. The task is told how many
    projections arrive. Returns that task.
    """
    pm = Ufo.PluginManager()
    sinos = pm.get_task('transpose-projections')

    if args.number:
        region = (args.start, args.start + args.number, args.step)
        num_projections = len(list(range(*region)))
    else:
        # Only every args.step-th file from args.start on is read (as counted in
        # create_flat_correct_pipeline), so do not count all files.
        num_files = len(get_filenames(args.projections))
        num_projections = len(range(args.start, num_files, args.step))
    sinos.props.number = num_projections

    if args.darks and args.flats:
        source = create_flat_correct_pipeline(args, graph)
    else:
        source = get_task('read')
        source.props.path = args.projections
        set_node_props(source, args)

    graph.connect_nodes(source, sinos)
    return sinos
def run_sinogram_generation(args):
    """Make the sinograms with arguments provided by *args*.

    The rows ``args.y`` .. ``args.y + args.height`` are processed in passes of
    ``args.y_step * args.pass_size`` rows (one pass if ``args.pass_size`` is unset). Every pass
    after the first is written with ``args.output_append`` set. Note that ``args.y``,
    ``args.height`` and ``args.output_append`` are modified.
    """
    if not args.height:
        args.height = determine_shape(args, args.projections)[1] - args.y

    end_row = args.y + args.height
    rows_per_pass = args.y_step * args.pass_size if args.pass_size else args.height
    boundaries = list(range(args.y, end_row, rows_per_pass)) + [end_row]

    for pass_index, (pass_start, pass_stop) in enumerate(zip(boundaries[:-1], boundaries[1:])):
        args.y = pass_start
        args.height = pass_stop - pass_start
        graph = Ufo.TaskGraph()
        scheduler = Ufo.Scheduler()
        args.output_append = pass_index != 0
        writer = get_writer(args)
        graph.connect_nodes(create_sinogram_pipeline(args, graph), writer)
        scheduler.run(graph)
def create_projection_filtering_pipeline(args, graph, processing_node=None):
    """Create a pad -> fft -> filter -> ifft [-> crop] pipeline filtering the projection rows.

    The crop task is only added when ``args.projection_crop_after`` is ``'filter'``. Otherwise
    the filtered projections stay padded. Padding is configured by ``setup_padding``. Returns
    the tuple ``(first, last)`` of tasks.
    """
    pm = Ufo.PluginManager()
    pad = get_task('pad', processing_node=processing_node)
    fft = get_task('fft', processing_node=processing_node)
    ifft = get_task('ifft', processing_node=processing_node)
    fltr = get_task('filter', processing_node=processing_node)
    crop_after_filter = args.projection_crop_after == 'filter'
    crop = get_task('crop', processing_node=processing_node) if crop_after_filter else None
    setup_padding(pad, args.width, args.height, args.projection_padding_mode, crop=crop)

    for transform in (fft, ifft):
        transform.props.dimensions = 1
    fltr.props.filter = args.projection_filter
    fltr.props.scale = args.projection_filter_scale
    fltr.props.cutoff = args.projection_filter_cutoff

    chain = [pad, fft, fltr, ifft] + ([crop] if crop else [])
    for upstream, downstream in zip(chain, chain[1:]):
        graph.connect_nodes(upstream, downstream)
    return (pad, chain[-1])
def create_preprocessing_pipeline(args, graph, source=None, processing_node=None,
                                  cone_beam_weight=True, make_reader=True):
    """If *make_reader* is True, create a read task if *source* is None and no dark and flat fields
    are given.

    The pipeline is built in this order, each step only if enabled by *args*:
    1. input (*source*, flat-field correction or a reader);
    2. absorptivity (-log);
    3. transposition;
    4. cone beam weighting;
    5. phase retrieval;
    6. projection filtering.
    Returns the last task of the pipeline, or None if no task was created.

    *args* is modified: missing ``width``/``height`` are filled in from the input, and
    ``transpose_input`` swaps them.

    Raises RuntimeError if the width cannot be determined, or if a reader is to be created but
    ``args.projections`` is not set.
    """
    import numpy as np
    if not (args.width and args.height):
        width, height = determine_shape(args, args.projections)
        if not width:
            raise RuntimeError("Could not determine width from the input")
        if not args.width:
            args.width = width
        if not args.height:
            args.height = height - args.y
    LOG.debug('Image width x height: %d x %d', args.width, args.height)
    current = None
    if source:
        current = source
    elif args.darks and args.flats:
        current = create_flat_correct_pipeline(args, graph, processing_node=processing_node)
    else:
        if make_reader:
            current = get_task('read')
            set_node_props(current, args)
            if not args.projections:
                raise RuntimeError('--projections not set')
            setup_read_task(current, args.projections, args)
    if args.absorptivity:
        absorptivity = get_task('calculate', processing_node=processing_node)
        absorptivity.props.expression = 'v <= 0 ? 0.0f : -log(v)'
        if current:
            graph.connect_nodes(current, absorptivity)
        current = absorptivity
    if args.transpose_input:
        transpose = get_task('transpose')
        if current:
            graph.connect_nodes(current, transpose)
        current = transpose
        # NOTE(review): the swap happens on every call, so building several pipelines from the
        # same *args* (e.g. once per region in lamino's _setup_graph) alternates width and
        # height -- verify whether that case can occur.
        tmp = args.width
        args.width = args.height
        args.height = tmp
    if cone_beam_weight and not np.all(np.isinf(args.source_position_y)):
        # Cone beam projection weight
        LOG.debug('Enabling cone beam weighting')
        weight = get_task('cone-beam-projection-weight', processing_node=processing_node)
        weight.props.source_distance = (-np.array(args.source_position_y)).tolist()
        weight.props.detector_distance = args.detector_position_y
        # Default centers: the image middle, shifted by half a pixel for odd sizes
        weight.props.center_position_x = args.center_position_x or [args.width / 2. + (args.width % 2) * 0.5]
        weight.props.center_position_z = args.center_position_z or [args.height / 2. + (args.height % 2) * 0.5]
        weight.props.axis_angle_x = args.axis_angle_x
        if current:
            graph.connect_nodes(current, weight)
        current = weight
    if args.energy is not None and args.propagation_distance is not None:
        pr_first, pr_last = create_phase_retrieval_pipeline(args, graph,
                                                            processing_node=processing_node)
        if current:
            graph.connect_nodes(current, pr_first)
        current = pr_last
    if args.projection_filter != 'none':
        pf_first, pf_last = create_projection_filtering_pipeline(args, graph,
                                                                 processing_node=processing_node)
        if current:
            graph.connect_nodes(current, pf_first)
        current = pf_last
    return current
def run_preprocessing(args):
    """Preprocess the input described by *args* and write the result.

    Builds a graph of the preprocessing pipeline followed by the writer from
    ``get_writer(args)`` and runs it.
    """
    graph = Ufo.TaskGraph()
    scheduler = Ufo.Scheduler()
    plugin_manager = Ufo.PluginManager()
    writer = get_writer(args)
    last_task = create_preprocessing_pipeline(args, graph)
    graph.connect_nodes(last_task, writer)
    scheduler.run(graph)
tofu-0.12.0/tofu/reco.py 0000664 0000000 0000000 00000026741 14237137211 0015041 0 ustar 00root root 0000000 0000000 import os
import logging
import glob
import tempfile
import sys
import numpy as np
from gi.repository import Ufo
from tofu.preprocess import create_flat_correct_pipeline
from tofu.util import (set_node_props, setup_read_task, get_filenames,
read_image, determine_shape, setup_padding)
from tofu.tasks import get_task, get_writer
LOG = logging.getLogger(__name__)
pm = Ufo.PluginManager()
def get_dummy_reader(params):
    """Return a ``dummy-data`` task generating input, plus its width and height.

    Both ``params.width`` and ``params.height`` are required. ``params.number``
    images are generated (1 if unset).

    Raises RuntimeError if either dimension is missing. The old check used
    ``and``, which let a single missing dimension through to the task.
    """
    if params.width is None or params.height is None:
        raise RuntimeError("You have to specify --width and --height when generating data.")
    width, height = params.width, params.height
    reader = get_task('dummy-data', width=width, height=height, number=params.number or 1)
    return reader, width, height
def get_file_reader(params):
    """Create a ``read`` task with the module-level plugin manager and apply *params* to it."""
    task = pm.get_task('read')
    set_node_props(task, params)
    return task
def get_projection_reader(params):
    """Return a reader for ``params.projections`` together with the input width and height."""
    source_path = params.projections
    reader = get_file_reader(params)
    setup_read_task(reader, source_path, params)
    width, height = determine_shape(params, source_path)
    return reader, width, height
def get_sinogram_reader(params):
    """Return a reader for ``params.sinograms`` together with the input width and height."""
    source_path = params.sinograms
    reader = get_file_reader(params)
    setup_read_task(reader, source_path, params)
    width, height = determine_shape(params, path=source_path)
    return reader, width, height
def tomo(params):
    """Build and run a UFO reconstruction graph for *params*.

    The input is ``params.projections`` (transposed to sinograms inside the graph),
    ``params.sinograms`` or, when neither is given, a ``dummy-data`` generator.
    ``params.method`` selects the reconstruction:
    - ``'fbp'``;
    - one of the ``ir`` package methods ``'sart'``, ``'sirt'``, ``'sbtv'`` and ``'asdpocs'``;
    - ``'dfi'``.
    The result goes to the task returned by :func:`get_writer`.

    :raises RuntimeError: if both ``params.projections`` and ``params.sinograms`` are given.
    :returns: the scheduler's ``time`` property, which is logged as seconds.
    """
    # Create reader and writer
    if params.projections and params.sinograms:
        raise RuntimeError("Cannot specify both --projections and --sinograms.")
    if params.projections is None and params.sinograms is None:
        reader, width, height = get_dummy_reader(params)
    else:
        if params.projections:
            reader, width, height = get_projection_reader(params)
        else:
            reader, width, height = get_sinogram_reader(params)
    # Default to the image center; because of ``or`` an explicit axis of 0 also falls back
    axis = params.axis or width / 2.0
    if params.projections and params.resize:
        # Scale the geometry to the resized data. NOTE(review): ``/=`` makes width and height
        # floats, which are later assigned to task properties — confirm they accept them
        width /= params.resize
        height /= params.resize
        axis /= params.resize
    LOG.debug("Input dimensions: {}x{} pixels".format(width, height))
    writer = get_writer(params)
    # Setup graph depending on the chosen method and input data
    g = Ufo.TaskGraph()
    if params.projections is not None:
        # Number of projections that will be read, i.e. the sinogram height
        if params.number:
            count = len(list(range(params.start, params.start + params.number, params.step)))
        else:
            count = len(get_filenames(params.projections))
        LOG.debug("Number of projections: {}".format(count))
        sino_output = get_task('transpose-projections', number=count)
        if params.darks and params.flats:
            g.connect_nodes(create_flat_correct_pipeline(params, g), sino_output)
        else:
            g.connect_nodes(reader, sino_output)
        if height:
            # Sinogram height is the one needed for further padding
            height = count
    else:
        sino_output = reader
    if params.method == 'fbp':
        # Chain: [pad ->] fft -> filter -> ifft -> backproject, with an optional crop
        fft = get_task('fft', dimensions=1)
        ifft = get_task('ifft', dimensions=1)
        fltr = get_task('filter', filter=params.projection_filter,
                        cutoff=params.projection_filter_cutoff)
        bp = get_task('backproject', axis_pos=axis)
        last_node = bp
        if params.angle:
            bp.props.angle_step = params.angle
        if params.offset:
            bp.props.angle_offset = params.offset
        if width and height:
            # Pad the image with its extent to prevent reconstruction ring artifacts
            pad = get_task('pad')
            crop = get_task('crop')
            if params.projection_crop_after == 'filter':
                crop_after_filter = crop
            else:
                crop_after_filter = None
            # Only the first returned value is used, as the added padding width
            padding_width = setup_padding(pad, width, height, params.projection_padding_mode,
                                          crop=crop_after_filter)[0]
            LOG.debug("Padding input to: {}x{} pixels".format(pad.props.width, pad.props.height))
            g.connect_nodes(sino_output, pad)
            g.connect_nodes(pad, fft)
            g.connect_nodes(fft, fltr)
            g.connect_nodes(fltr, ifft)
            if crop_after_filter:
                # Crop between filtering and backprojection; setup_padding received the crop
                # task, presumably to configure it
                g.connect_nodes(ifft, crop)
                g.connect_nodes(crop, bp)
            else:
                # Backproject the padded sinogram: shift the axis by half the padding and crop
                # the reconstructed slice back to a width x width square
                bp.props.axis_pos = axis + padding_width / 2
                crop.props.x = padding_width // 2
                crop.props.y = padding_width // 2
                crop.props.width = width
                crop.props.height = width
                g.connect_nodes(ifft, bp)
                g.connect_nodes(bp, crop)
                last_node = crop
        else:
            # Size unknown: no padding, optionally let the inverse FFT crop its output
            if params.crop_width:
                ifft.props.crop_width = int(params.crop_width)
                LOG.debug("Cropping to {} pixels".format(ifft.props.crop_width))
            g.connect_nodes(sino_output, fft)
            g.connect_nodes(fft, fltr)
            g.connect_nodes(fltr, ifft)
            g.connect_nodes(ifft, bp)
        g.connect_nodes(last_node, writer)
    if params.method in ('sart', 'sirt', 'sbtv', 'asdpocs'):
        # Iterative reconstruction with tasks from the 'ir' plugin package
        projector = pm.get_task_from_package('ir', 'parallel-projector')
        projector.set_properties(model='joseph', is_forward=False)
        projector.set_properties(axis_position=axis)
        # Default step is pi / 180 (one degree, assuming radians — confirm with the ir docs)
        projector.set_properties(step=params.angle if params.angle else np.pi / 180.0)
        method = pm.get_task_from_package('ir', params.method)
        method.set_properties(projector=projector, num_iterations=params.num_iterations)
        if params.method in ('sart', 'sirt'):
            method.set_properties(relaxation_factor=params.relaxation_factor)
        if params.method == 'asdpocs':
            minimizer = pm.get_task_from_package('ir', 'sirt')
            method.set_properties(df_minimizer=minimizer)
        if params.method == 'sbtv':
            # FIXME: the lambda keyword is preventing from the following
            # assignment ...
            # method.props.lambda = params.lambda
            method.set_properties(mu=params.mu)
        g.connect_nodes(sino_output, method)
        g.connect_nodes(method, writer)
    if params.method == 'dfi':
        # Chain: zeropad -> fft -> dfi-sinc -> swap-quadrants -> 2D ifft -> swap-quadrants
        oversampling = params.oversampling or 1
        pad = get_task('zeropad', center_of_rotation=axis, oversampling=oversampling)
        fft = get_task('fft', dimensions=1, auto_zeropadding=0)
        dfi = get_task('dfi-sinc')
        ifft = get_task('ifft', dimensions=2)
        swap_forward = get_task('swap-quadrants')
        swap_backward = get_task('swap-quadrants')
        if params.angle:
            dfi.props.angle_step = params.angle
        g.connect_nodes(sino_output, pad)
        g.connect_nodes(pad, fft)
        g.connect_nodes(fft, dfi)
        g.connect_nodes(dfi, swap_forward)
        g.connect_nodes(swap_forward, ifft)
        g.connect_nodes(ifft, swap_backward)
        if width:
            # Cut a width x width square from the center of the result
            crop = get_task('crop')
            crop.set_properties(from_center=True, width=width, height=width)
            g.connect_nodes(swap_backward, crop)
            g.connect_nodes(crop, writer)
        else:
            g.connect_nodes(swap_backward, writer)
    scheduler = Ufo.Scheduler()
    # Only set tracing if this scheduler exposes the property
    if hasattr(scheduler.props, 'enable_tracing'):
        LOG.debug("Use tracing: {}".format(params.enable_tracing))
        scheduler.props.enable_tracing = params.enable_tracing
    scheduler.run(g)
    duration = scheduler.props.time
    LOG.info("Execution time: {} s".format(duration))
    return duration
def estimate_center(params):
    """Estimate the rotation axis with the method chosen by ``params.estimate_method``.

    ``'reconstruction'`` selects trial reconstructions; any other value selects correlation of
    opposite projections.
    """
    if params.estimate_method != 'reconstruction':
        return estimate_center_by_correlation(params)
    return estimate_center_by_reconstruction(params)
def estimate_center_by_reconstruction(params):
    """Estimate the rotation axis by reconstructing one sinogram at trial axis positions.

    The sinogram in the middle of the sorted ``*.tif`` files in ``params.sinograms`` is
    reconstructed with :func:`tomo` at five positions spread over the current search window.
    The position with the lowest normalized integrated absolute value (``Q_IA``) of the
    slice is kept, and the window is halved. This repeats ``params.num_iterations`` times.

    Note that *params* is modified in place: ``sinograms``, ``output`` and ``axis`` are
    overwritten.

    :raises RuntimeError: if ``params.projections`` is set or no sinograms are found.
    :returns: the best axis position found.
    """
    if params.projections is not None:
        raise RuntimeError("Cannot estimate axis from projections")
    sinos = sorted(glob.glob(os.path.join(params.sinograms, '*.tif')))
    if not sinos:
        raise RuntimeError("No sinograms found in {}".format(params.sinograms))
    # Use a sinogram that probably has some interesting data
    filename = sinos[len(sinos) // 2]
    sinogram = read_image(filename)
    initial_width = sinogram.shape[1]
    # Mean row sum of the sinogram, used to normalize the quality measures
    m0 = np.mean(np.sum(sinogram, axis=1))
    center = initial_width / 2.0
    width = initial_width / 2.0
    new_center = center
    tmp_dir = tempfile.mkdtemp()
    tmp_output = os.path.join(tmp_dir, 'slice-0.tif')
    params.sinograms = filename
    params.output = os.path.join(tmp_dir, 'slice-%i.tif')

    def heaviside(A):
        return (A >= 0.0) * 1.0

    def get_score(guess, m0):
        # Run reconstruction with new guess
        params.axis = guess
        tomo(params)
        # Analyse reconstructed slice
        result = read_image(tmp_output)
        Q_IA = float(np.sum(np.abs(result)) / m0)
        Q_IN = float(-np.sum(result * heaviside(-result)) / m0)
        LOG.info("Q_IA={}, Q_IN={}".format(Q_IA, Q_IN))
        return Q_IA

    def best_center(center, width):
        trials = [center + (width / 4.0) * x for x in range(-2, 3)]
        scores = [(guess, get_score(guess, m0)) for guess in trials]
        LOG.info(scores)
        # Lowest score wins. ``sorted(cmp=...)`` and ``cmp`` no longer exist in Python 3;
        # ``min`` returns the first of equal scores, exactly like the stable sort did
        return min(scores, key=lambda item: item[1])[0]

    try:
        for i in range(params.num_iterations):
            LOG.info("Estimate iteration: {}".format(i))
            new_center = best_center(new_center, width)
            LOG.info("Currently best center: {}".format(new_center))
            width /= 2.0
    finally:
        # Clean up the temporary slice even if a reconstruction failed
        try:
            os.remove(tmp_output)
            os.removedirs(tmp_dir)
        except OSError:
            LOG.info("Could not remove {} or {}".format(tmp_output, tmp_dir))
    return new_center
def estimate_center_by_correlation(params):
    """Use correlation to estimate center of rotation for tomography.

    Reads the first projection and the projection at index ``params.start + params.number``
    (the last file when ``params.number`` is not set). If darks and flats are given, both
    projections are corrected with the first dark and flat. The rows selected by
    ``params.y``, ``params.height`` and ``params.y_step`` are then passed to
    :func:`compute_rotation_axis`.

    NOTE(review): with ``params.number`` the second index is one past the range
    ``[start, start + number)`` and ``params.step`` is ignored — confirm this is intended.
    """
    def flat_correct(flat, radio):
        nonzero = np.where(radio != 0)
        result = np.zeros_like(radio)
        result[nonzero] = flat[nonzero] / radio[nonzero]
        # log(1) = 0
        result[result <= 0] = 1
        return np.log(result)

    projection_names = get_filenames(params.projections)
    # ``np.float`` was removed in NumPy 1.24, use the explicit 64-bit type
    first = read_image(projection_names[0]).astype(np.float64)
    last_index = params.start + params.number if params.number else -1
    last = read_image(projection_names[last_index]).astype(np.float64)
    if params.darks and params.flats:
        dark = read_image(get_filenames(params.darks)[0]).astype(np.float64)
        flat = read_image(get_filenames(params.flats)[0]) - dark
        first = flat_correct(flat, first - dark)
        last = flat_correct(flat, last - dark)
    # Without a height use all rows from params.y on; the former ``-1`` sentinel produced
    # slice(y, y - 1), which lost the last row or was empty
    y_stop = first.shape[0]
    if params.height:
        y_stop = min(params.y + params.height, y_stop)
    y_region = slice(params.y, y_stop, params.y_step)
    first = first[y_region, :]
    last = last[y_region, :]
    return compute_rotation_axis(first, last)
def compute_rotation_axis(first_projection, last_projection):
    """
    Compute the tomographic rotation axis based on cross-correlation technique.

    :param first_projection: projection at 0 deg.
    :param last_projection: projection at 180 deg, same shape as *first_projection*.
    :returns: axis position in pixels, computed from the column of the correlation maximum.
    """
    from scipy.signal import fftconvolve

    num_columns = first_projection.shape[1]
    centered_first = first_projection - first_projection.mean()
    centered_last = last_projection - last_projection.mean()
    # The 180 deg projection is already mirrored horizontally by the rotation. fftconvolve
    # flips its second argument along both axes, so flipping the rows first makes it a
    # cross-correlation with the horizontally mirrored last projection.
    correlation = fftconvolve(centered_first, np.flipud(centered_last), mode='same')
    peak_column = np.unravel_index(np.argmax(correlation), correlation.shape)[1]
    return (num_columns / 2.0 + peak_column) / 2
tofu-0.12.0/tofu/tasks.py 0000664 0000000 0000000 00000002443 14237137211 0015227 0 ustar 00root root 0000000 0000000 import logging
from gi.repository import Ufo
LOG = logging.getLogger(__name__)
PLUGIN_MANAGER = Ufo.PluginManager()
def get_task(name, processing_node=None, **kwargs):
    """Create the UFO task *name* and set *kwargs* as its properties.

    The task is assigned to *processing_node* only when a node is given and the task uses a
    GPU.
    """
    task = PLUGIN_MANAGER.get_task(name)
    task.set_properties(**kwargs)
    if not processing_node or not task.uses_gpu():
        return task
    LOG.debug("Assigning task '%s' to node %d", name, processing_node.get_index())
    task.set_proc_node(processing_node)
    return task
def get_writer(params):
    """Create the output task for *params*.

    If ``params.dry_run`` is set, a ``null`` task with ``download=True`` is returned and the
    data is discarded. Otherwise a ``write`` task to ``params.output`` is returned. It gets
    the append flag, a non-default bit depth, and a value range if both bounds are given.
    ``bytes_per_file`` and ``tiff_bigtiff`` are set only if the task has these properties.
    """
    dry_run = 'dry_run' in params and params.dry_run
    if dry_run:
        LOG.debug("Discarding data output")
        return get_task('null', download=True)

    outname = params.output
    LOG.debug("Writing output to {}".format(outname))
    writer = get_task('write', filename=outname)
    writer.props.append = params.output_append
    if params.output_bitdepth != 32:
        writer.props.bits = params.output_bitdepth
    has_range = params.output_minimum is not None and params.output_maximum is not None
    if has_range:
        writer.props.minimum = params.output_minimum
        writer.props.maximum = params.output_maximum
    if hasattr(writer.props, 'bytes_per_file'):
        writer.props.bytes_per_file = params.output_bytes_per_file
    if hasattr(writer.props, 'tiff_bigtiff'):
        # Request BigTIFF when a single file may grow beyond 4 GiB
        writer.props.tiff_bigtiff = params.output_bytes_per_file > 2 ** 32
    return writer
tofu-0.12.0/tofu/tests/ 0000775 0000000 0000000 00000000000 14237137211 0014667 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/tests/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0016766 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/tests/composites/ 0000775 0000000 0000000 00000000000 14237137211 0017054 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/tests/composites/cmp.cm 0000664 0000000 0000000 00000005666 14237137211 0020171 0 ustar 00root root 0000000 0000000 {
"name": "cmp",
"caption": "cmp",
"models": {
"Null": {
"model": {
"caption": "Null",
"properties": {
"download": [
false,
true
],
"finish": [
false,
true
],
"durations": [
false,
true
]
}
},
"visible": true,
"position": {
"x": 756.0,
"y": 272.0
},
"name": "null"
},
"Read": {
"model": {
"caption": "Read",
"properties": {
"path": [
".",
true
],
"start": [
0,
false
],
"number": [
4294967295,
true
],
"step": [
1,
false
],
"y": [
0,
false
],
"height": [
0,
false
],
"y-step": [
1,
false
],
"convert": [
true,
false
],
"raw-width": [
0,
false
],
"raw-height": [
0,
false
],
"raw-bitdepth": [
0,
false
],
"raw-pre-offset": [
0,
false
],
"raw-post-offset": [
0,
false
],
"type": [
"unspecified",
false
],
"retries": [
0,
false
],
"retry-timeout": [
1,
false
]
}
},
"visible": true,
"position": {
"x": 266.0,
"y": 242.0
},
"name": "read"
}
},
"connections": [
[
"Read",
0,
"Null",
0
]
],
"links": []
} tofu-0.12.0/tofu/tests/composites/cmp_2.cm 0000664 0000000 0000000 00000003475 14237137211 0020406 0 ustar 00root root 0000000 0000000 {
"name": "cmp_2",
"caption": "cmp_2",
"models": {
"Dummy Data": {
"model": {
"caption": "Dummy Data",
"properties": {
"width": [
1,
true
],
"height": [
1,
true
],
"depth": [
1,
true
],
"number": [
1,
true
],
"init": [
0.0,
true
],
"metadata": [
false,
true
]
}
},
"visible": true,
"position": {
"x": 359.0,
"y": 357.0
},
"name": "dummy_data"
},
"Null": {
"model": {
"caption": "Null",
"properties": {
"download": [
false,
true
],
"finish": [
false,
true
],
"durations": [
false,
true
]
}
},
"visible": true,
"position": {
"x": 859.0,
"y": 459.0
},
"name": "null"
}
},
"connections": [
[
"Dummy Data",
0,
"Null",
0
]
],
"links": []
} tofu-0.12.0/tofu/tests/conftest.py 0000664 0000000 0000000 00000003044 14237137211 0017067 0 ustar 00root root 0000000 0000000 import pytest
from PyQt5.QtWidgets import QInputDialog
from tofu.flow.main import get_filled_registry
from tofu.flow.scene import UfoScene
from tofu.flow.propertylinksmodels import PropertyLinksModel, NodeTreeModel
@pytest.fixture(scope='function')
def nodes(monkeypatch):
    """Return a dict of nodes created in a fresh :class:`UfoScene`.

    Keys:
    - ``'cpm'``: a composite built from a ``read`` and a ``pad`` node;
    - ``'read'`` and ``'read_1'`` ... ``'read_4'``;
    - ``'image_viewer'``;
    - ``'average'``.
    The creation order is kept fixed because tests rely on the resulting node captions.
    """
    reg = get_filled_registry()
    scene = UfoScene(reg)
    nodes = {}
    # Composite node
    for name in ['read', 'pad']:
        model_cls = reg.create(name)
        node = scene.create_node(model_cls)
        # Presumably create_composite groups the selected nodes
        node.graphics_object.setSelected(True)
    # Answer the composite-name dialog with 'cpm' instead of waiting for user input
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
    nodes['cpm'] = scene.create_composite()
    nodes['cpm'].graphics_object.setSelected(False)
    # Simple nodes
    for i in range(5):
        name = f'read_{i}' if i else 'read'
        model_cls = reg.create('read')
        nodes[name] = scene.create_node(model_cls)
    model_cls = reg.create('image_viewer')
    nodes['image_viewer'] = scene.create_node(model_cls)
    model_cls = reg.create('average')
    nodes['average'] = scene.create_node(model_cls)
    return nodes
@pytest.fixture(scope='function')
def scene():
    """Return an empty :class:`UfoScene` backed by a freshly filled node registry."""
    return UfoScene(get_filled_registry())
@pytest.fixture(scope='function')
def scene_with_composite(nodes):
    """Return a new scene sharing the registry of the ``nodes`` fixture.

    Tests use it to instantiate the ``cpm`` composite created by that fixture.
    """
    registry = nodes['cpm'].model._registry
    return UfoScene(registry)
@pytest.fixture(scope='function')
def node_model():
    """Return an empty :class:`NodeTreeModel` with a single column."""
    tree_model = NodeTreeModel()
    tree_model.setColumnCount(1)
    return tree_model
@pytest.fixture(scope='function')
def link_model(node_model):
    """Return a :class:`PropertyLinksModel` built on the ``node_model`` fixture."""
    return PropertyLinksModel(node_model)
tofu-0.12.0/tofu/tests/flow_util.py 0000664 0000000 0000000 00000001666 14237137211 0017256 0 ustar 00root root 0000000 0000000 def populate_link_model(link_model, nodes):
    """Add three property links to *link_model* and return their records.

    The links are:
    - ``number`` of node ``'read'``;
    - ``height`` of node ``'read_2'``;
    - ``y`` of the ``Read`` model inside composite ``'cpm'``.
    The keys refer to the ``nodes`` fixture. Record *i* is added with
    ``link_model.add_item(node, model, prop, 0, i)``.

    :returns: list of ``[node, model, property_name]`` records in insertion order.
    """
    read = nodes['read']
    read_2 = nodes['read_2']
    composite = nodes['cpm']
    records = [[read, read.model, 'number'],
               [read_2, read_2.model, 'height'],
               [composite, composite.model['Read'], 'y']]
    for (i, (node, model, prop)) in enumerate(records):
        link_model.add_item(node, model, prop, 0, i)
    return records
def get_index_from_treemodel(node_model, row, prop_name):
    """Return the model index of the child item *prop_name* of the node item in *row*.

    :param node_model: tree model whose column-0 items have property child items.
    :param row: row of the node item.
    :param prop_name: text of the child item to find.
    :raises ValueError: if the node item has no child with text *prop_name*.
    """
    item = node_model.item(row, 0)
    # Stop after the last child; walking past it yields None and an unhelpful AttributeError
    for i in range(item.rowCount()):
        prop_item = item.child(i)
        if prop_item.text() == prop_name:
            return node_model.indexFromItem(prop_item)
    raise ValueError(f"Property '{prop_name}' not found under row {row}")
def add_nodes_to_scene(scene, model_names=None):
    """Create one node per model name in *scene* and return the nodes in order.

    A missing or empty *model_names* defaults to a single ``read`` node.
    """
    names = model_names or ['read']
    return [scene.create_node(scene.registry.create(name)) for name in names]
tofu-0.12.0/tofu/tests/test_flow_execution.py 0000664 0000000 0000000 00000011437 14237137211 0021340 0 ustar 00root root 0000000 0000000 import pytest
from tofu.flow.execution import get_gpu_splitting_models, UfoExecutor
from tofu.flow.main import get_filled_registry
from tofu.flow.scene import UfoScene
@pytest.fixture(scope='function')
def scene():
    """Return a scene containing a connected ``dummy_data -> pad -> null`` chain.

    Each node is also stored as a scene attribute named after its model.
    """
    registry = get_filled_registry()
    ufo_scene = UfoScene(registry)
    for model_name in ('dummy_data', 'pad', 'null'):
        # Set nodes as scene attributes for convenience
        created = ufo_scene.create_node(registry.create(model_name))
        setattr(ufo_scene, model_name, created)
    ufo_scene.create_connection(ufo_scene.dummy_data['output'][0], ufo_scene.pad['input'][0])
    ufo_scene.create_connection(ufo_scene.pad['output'][0], ufo_scene.null['input'][0])
    return ufo_scene
@pytest.fixture(scope='function')
def executor():
    """Return a fresh :class:`UfoExecutor`."""
    ufo_executor = UfoExecutor()
    return ufo_executor
class TestUfoExecutor:
    """Tests of :class:`UfoExecutor`, mostly on the ``dummy_data -> pad -> null`` scene."""

    def test_init(self, executor):
        # Placeholder; the executor is already constructed by the ``executor`` fixture
        ...

    def test_reset(self, executor):
        # A freshly created executor must be in its reset state
        assert not executor._aborted
        assert executor._schedulers == []
        assert executor.num_generated == 0

    def test_abort(self, executor):
        # Aborting must emit ``execution_finished``
        self.called = False

        def slot():
            self.called = True

        executor.execution_finished.connect(slot)
        executor.abort()
        assert self.called

    def test_on_processed(self, executor):
        # Each on_processed call must emit ``processed_signal`` and increment num_generated
        self.num_generated = 0

        def slot():
            self.num_generated += 1

        executor.processed_signal.connect(slot)
        executor.on_processed(None)
        executor.on_processed(None)
        assert self.num_generated == executor.num_generated == 2

    def test_setup_ufo_graph(self, qtbot, scene, executor):
        # Requires at least one GPU node
        graph = scene.get_simple_node_graphs()[0]
        gpus = executor._resources.get_gpu_nodes()
        assert gpus
        executor.setup_ufo_graph(graph, gpu=gpus[0], region=None,
                                 signalling_model=scene.dummy_data.model)

    def test_run_ufo_graph(self, qtbot, scene, executor):
        # Requires at least one GPU node; the same graph is run with both scheduler kinds
        graph = scene.get_simple_node_graphs()[0]
        gpus = executor._resources.get_gpu_nodes()
        assert gpus
        ufo_graph = executor.setup_ufo_graph(graph, gpu=gpus[0], region=None,
                                             signalling_model=scene.dummy_data.model)
        # Run with default scheduler
        executor._run_ufo_graph(ufo_graph, False)
        # Run with fixed scheduler
        executor._run_ufo_graph(ufo_graph, True)

    # def test_check_graph(self, qtbot, scene, executor):
    #     # TODO: implement this when memory-in is implemented and there is something to test

    def test_run(self, qtbot, scene, executor):
        """Run a working graph, then one whose reader fails.

        For the working graph the test checks that:
        - the reported number of inputs equals the generated number (10);
        - the processed count equals that number too;
        - the start and finish signals are emitted;
        - no exception is reported.
        For the second graph, the reader points to a nonexistent path and
        ``swallow_run_exceptions`` is enabled; ``exception_occured`` must be emitted.
        """
        def on_num_inputs_changed(number):
            self.num_inputs = number

        def on_processed(number):
            self.num_processed = number

        def on_execution_started():
            self.started = True

        def on_execution_finished():
            self.finished = True

        def on_exception_occured():
            self.exception = True

        scene.dummy_data.model['number'] = 10
        graph = scene.get_simple_node_graphs()[0]
        self.num_inputs = 0
        self.num_processed = 0
        self.started = False
        self.finished = False
        self.exception = None
        executor.number_of_inputs_changed.connect(on_num_inputs_changed)
        executor.processed_signal.connect(on_processed)
        executor.execution_started.connect(on_execution_started)
        executor.execution_finished.connect(on_execution_finished)
        executor.exception_occured.connect(on_exception_occured)
        with qtbot.waitSignal(signal=executor.execution_finished, timeout=100000):
            executor.run(graph)
        assert self.num_inputs == scene.dummy_data.model['number']
        assert self.num_processed == scene.dummy_data.model['number']
        assert self.started
        assert self.finished
        assert self.exception is None

        scene.remove_node(scene.dummy_data)
        # Create a reader and point it to a nonexistent path so that it raises an exception and
        # check that this exception has been processed by the executor
        setattr(scene, 'read', scene.create_node(scene.registry.create('read')))
        scene.create_connection(scene.read['output'][0], scene.pad['input'][0])
        # Make sure the path is nonsense
        scene.read.model['path'] = '/dfasf/fsdafsdaf/asd/asf'
        scene.read.model['number'] = 10
        graph = scene.get_simple_node_graphs()[0]
        executor.swallow_run_exceptions = True
        with qtbot.waitSignal(signal=executor.execution_finished):
            executor.run(graph)
        assert self.exception
def test_get_gpu_splitting_models(qtbot, scene, executor):
    """The fixture chain has no GPU-splitting model; a lone general_backproject node has one."""
    def count_splitting_models():
        first_graph = scene.get_simple_node_graphs()[0]
        return len(get_gpu_splitting_models(first_graph))

    assert count_splitting_models() == 0
    scene.clear_scene()
    scene.create_node(scene.registry.create('general_backproject'))
    assert count_splitting_models() == 1
tofu-0.12.0/tofu/tests/test_flow_main.py 0000664 0000000 0000000 00000046127 14237137211 0020265 0 ustar 00root root 0000000 0000000 import glob
import os
import pathlib
import pkg_resources
import pytest
import sys
from PyQt5.QtWidgets import QFileDialog, QInputDialog, QMessageBox
from xdg import xdg_data_home
from tofu.flow.execution import UfoExecutor
from tofu.flow.main import ApplicationWindow, get_filled_registry, GlobalExceptionHandler
from tofu.flow.scene import UfoScene
from tofu.flow.util import FlowError
from tofu.tests.flow_util import add_nodes_to_scene
@pytest.fixture(scope='function')
def app_window(qtbot, scene):
    """Return an :class:`ApplicationWindow` for the ``scene`` fixture, registered with qtbot."""
    main_window = ApplicationWindow(scene)
    qtbot.addWidget(main_window)
    return main_window
class TestApplicationWindow:
def test_init(self, qtbot, app_window):
    """A new window must have a scene and an executor."""
    for attribute in ('ufo_scene', 'executor'):
        assert getattr(app_window, attribute)
def test_on_save(self, monkeypatch, app_window):
    """on_save starts in the default flows directory and remembers the last chosen one."""
    def getSaveFileNameDefault(inst, header, path, fltr):
        # The user accepts the suggested directory
        return (os.path.join(path, 'flow.flow'), True)

    def getSaveFileName(inst, header, path, fltr):
        # The user navigates to a different directory
        return (os.path.join('foo', 'bar', 'flow.flow'), True)

    # Don't actually write to disk
    monkeypatch.setattr(UfoScene, "save", lambda *args: None)
    # Default directory
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault)
    app_window.on_save()
    directory = os.path.join(xdg_data_home(), 'tofu', 'flows')
    assert os.path.exists(directory)
    assert app_window.last_dirs['scene'] == directory
    # When user picks a different directory it must be remembered
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName)
    app_window.on_save()
    assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar')
    # And used the next time
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault)
    app_window.on_save()
    assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar')
def test_on_open(self, monkeypatch, app_window):
    """on_open starts in the flows directory (or home) and remembers the last chosen one."""
    def getOpenFileNameDefault(inst, header, path, fltr):
        # The user accepts the suggested directory
        return (os.path.join(path, 'flow.flow'), True)

    def getOpenFileName(inst, header, path, fltr):
        # The user navigates to a different directory
        return (os.path.join('foo', 'bar', 'flow.flow'), True)

    # Don't actually read from disk
    monkeypatch.setattr(UfoScene, "load", lambda *args: None)
    # Default directory
    monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault)
    app_window.on_open()
    directory = os.path.join(xdg_data_home(), 'tofu', 'flows')
    if not os.path.exists(directory):
        # Expected fallback when no flows have been saved yet. NOTE(review): this is a
        # pathlib.Path while the other branch is a str — confirm last_dirs stores the same type
        directory = pathlib.Path.home()
    assert app_window.last_dirs['scene'] == directory
    # When user picks a different directory it must be remembered
    monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileName)
    app_window.on_open()
    assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar')
    # And used the next time
    monkeypatch.setattr(QFileDialog, "getOpenFileName", getOpenFileNameDefault)
    app_window.on_open()
    assert app_window.last_dirs['scene'] == os.path.join('foo', 'bar')
def test_on_exception_occured(self, qtbot, monkeypatch, app_window):
    """Reporting an exception must show a message box."""
    self.message_shown = False

    def fake_exec(box):
        self.message_shown = True

    monkeypatch.setattr(QMessageBox, "exec_", fake_exec)
    app_window.on_exception_occured('foo')
    assert self.message_shown
def test_on_number_of_inputs_changed(self, qtbot, app_window):
    """The progress bar maximum must follow the number of inputs."""
    expected_maximum = 123
    app_window.on_number_of_inputs_changed(expected_maximum)
    assert app_window.progress_bar.maximum() == expected_maximum
def test_on_processed(self, qtbot, app_window):
    """After on_processed(10) with 100 inputs the progress bar value must be 11."""
    total, processed = 100, 10
    app_window.on_number_of_inputs_changed(total)
    app_window.on_processed(processed)
    assert app_window.progress_bar.value() == processed + 1
def test_on_nodes_duplicated(self, qtbot, app_window):
    """Copying a selected node must not place the copy at the same height as the original."""
    source_node = add_nodes_to_scene(app_window.ufo_scene)[0]
    source_node.graphics_object.setSelected(True)
    app_window.ufo_scene.copy_nodes()
    scene_nodes = list(app_window.ufo_scene.nodes.values())
    first_y, second_y = (scene_nodes[k].graphics_object.pos().y() for k in (0, 1))
    assert first_y != second_y
def test_on_selection_menu_about_to_show(self, qtbot, monkeypatch, app_window):
    """Composite menu actions depend on the selection.

    - No composite selected: edit, expand and export are disabled.
    - One composite selected: all three are enabled.
    - Several composites selected: only expand is enabled.
    """
    # Nothing selected
    app_window.on_selection_menu_about_to_show()
    assert not app_window.edit_composite_action.isEnabled()
    assert not app_window.expand_composite_action.isEnabled()
    assert not app_window.export_composite_action.isEnabled()
    # Only non-composite nodes
    nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null'])
    app_window.on_selection_menu_about_to_show()
    assert not app_window.edit_composite_action.isEnabled()
    assert not app_window.expand_composite_action.isEnabled()
    assert not app_window.export_composite_action.isEnabled()
    # One composite
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
    for i in range(2):
        nodes[i].graphics_object.setSelected(True)
    app_window.ufo_scene.create_composite()
    app_window.on_selection_menu_about_to_show()
    assert app_window.edit_composite_action.isEnabled()
    assert app_window.expand_composite_action.isEnabled()
    assert app_window.export_composite_action.isEnabled()
    # More composites
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True))
    app_window.ufo_scene.clearSelection()
    nodes[-1].graphics_object.setSelected(True)
    app_window.ufo_scene.create_composite()
    for node in app_window.ufo_scene.nodes.values():
        node.graphics_object.setSelected(True)
    app_window.on_selection_menu_about_to_show()
    assert not app_window.edit_composite_action.isEnabled()
    assert app_window.expand_composite_action.isEnabled()
    assert not app_window.export_composite_action.isEnabled()
def test_skip_action(self, qtbot, app_window):
    """The skip action is enabled only when nodes are selected."""
    # No nodes selected, menu item must be disabled
    app_window.on_selection_menu_about_to_show()
    assert not app_window.skip_action.isEnabled()
    # Add a read -> average -> null chain and select the middle node
    nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'average', 'null'])
    app_window.ufo_scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0])
    app_window.ufo_scene.create_connection(nodes[1]['output'][0], nodes[2]['input'][0])
    average = nodes[1]
    average.graphics_object.setSelected(True)
    app_window.on_selection_menu_about_to_show()
    # Nodes selected, menu item must be enabled
    assert app_window.skip_action.isEnabled()
def test_on_edit_composite(self, qtbot, scene_with_composite, app_window):
    """Editing a selected composite must switch its model into editing mode."""
    app_window.ufo_scene = scene_with_composite
    node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0]
    node.graphics_object.setSelected(True)
    app_window.on_edit_composite()
    # Hand the composite's other view (used while editing) to qtbot for cleanup
    qtbot.addWidget(node.model._other_view)
    assert node.model.is_editing
def test_on_create_composite(self, qtbot, monkeypatch, scene_with_composite, app_window):
    """A run slider linked to a node property must follow it into a new composite."""
    # Answer the composite-name dialog automatically
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
    nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad'])
    # Link a model to the slider
    model = nodes[0].model
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', 'Read', model)
    # Create a composite
    for node in app_window.ufo_scene.nodes.values():
        node.graphics_object.setSelected(True)
    app_window.on_create_composite()
    composite = list(app_window.ufo_scene.nodes.values())[0].model
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model == composite.get_model_from_path(['Read'])
    assert prop_name == 'number'
def test_on_item_focus_in(self, qtbot, app_window, scene_with_composite):
    """Focusing a property links the run slider to it unless fix_run_slider is checked."""
    read, pad = add_nodes_to_scene(app_window.ufo_scene, model_names=['read', 'pad'])
    # Simple node
    model = read.model
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', model.caption, model)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model == model
    assert prop_name == 'number'
    # Unfixed slider: focusing another property relinks it
    app_window.fix_run_slider.setChecked(False)
    model = pad.model
    view_item = model._view._properties['y'].view_item
    app_window.on_item_focus_in(view_item, 'y', model.caption, model)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model == model
    assert prop_name == 'y'
    # Focus gets another widget, but the run slider must be linked to the one focused before the
    # fix option is checked
    app_window.fix_run_slider.setChecked(True)
    model = read.model
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', model.caption, model)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model == pad.model
    assert prop_name == 'y'
def test_on_node_deleted(self, qtbot, monkeypatch, app_window, scene_with_composite):
    """Deleting the node that holds the run-slider property must unlink the slider.

    The node may be a simple node, a composite, or a nested composite.
    """
    app_window.ufo_scene = scene_with_composite
    cpm, cpm_2, read = add_nodes_to_scene(app_window.ufo_scene,
                                          model_names=['cpm', 'cpm', 'read'])
    # Simple node
    model = read.model
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', model.caption, model)
    # remove in the scene doesn't seem to emit the signal, so use the window
    app_window.on_node_deleted(read)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model is None
    assert prop_name is None
    # Composite node
    model = cpm.model.get_model_from_path(['Read'])
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', model)
    # remove in the scene doesn't seem to emit the signal, so use the window
    app_window.on_node_deleted(cpm)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model is None
    assert prop_name is None
    # Nested composite node: wrap the second 'cpm' into a composite called 'parent'
    cpm_2.graphics_object.setSelected(True)
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('parent', True))
    app_window.on_create_composite()
    node = app_window.ufo_scene.selected_nodes()[0]
    model = node.model.get_model_from_path(['cpm 2', 'Read'])
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', 'parent->cpm 2->Read', model)
    # remove in the scene doesn't seem to emit the signal, so use the window
    app_window.on_node_deleted(node)
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model is None
    assert prop_name is None
def test_on_expand_composite(self, qtbot, scene_with_composite, app_window):
    """Expanding composites inserts their inner nodes and keeps the run slider linked."""
    app_window.ufo_scene = scene_with_composite
    nodes = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm', 'cpm'])
    for node in nodes:
        node.graphics_object.setSelected(True)
    app_window.on_expand_composite()
    # Two expanded 'cpm' composites yield uniquely captioned Read/Pad pairs
    captions = {node.model.caption for node in app_window.ufo_scene.nodes.values()}
    assert captions == {'Read 2', 'Pad 2', 'Read', 'Pad'}
    # Run slider
    # Create yet another composite and select a reader inside
    node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0]
    model = node.model.get_model_from_path(['Read'])
    view_item = model._view._properties['number'].view_item
    app_window.on_item_focus_in(view_item, 'number', 'cpm->Read', model)
    node.graphics_object.setSelected(True)
    app_window.on_expand_composite()
    # After expansion, the reader's index will be 3
    slider_model, prop_name = app_window.run_slider_key
    assert slider_model.caption == 'Read 3'
    assert prop_name == 'number'
def test_on_import_composites(self, qtbot, monkeypatch, app_window):
    """Importing composites registers them, remembers the directory and warns on overwrite."""
    tests_directory = pkg_resources.resource_filename(__name__, 'composites')

    def getOpenFileNamesDefault(inst, header, path, fltr):
        # Let's pretend there are files
        file_names = [os.path.join(path, 'foo.cm')]
        return (file_names, True)

    def getOpenFileNames(inst, header, path, fltr):
        # Select every composite file shipped with the tests
        file_names = sorted(glob.glob(os.path.join(tests_directory, '*.cm')))
        return (file_names, True)

    def exec_(inst):
        # Record that a message box would have been shown
        self.message_shown = True

    monkeypatch.setattr(QMessageBox, "exec_", exec_)
    # Nothing opened, nothing happens
    monkeypatch.setattr(QFileDialog, "getOpenFileNames", lambda *args: ([], True))
    app_window.on_import_composites()
    # Default directory
    monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNamesDefault)
    directory = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites')
    if not os.path.exists(directory):
        # NOTE(review): pathlib.Path here vs. str above — confirm last_dirs stores the same type
        directory = pathlib.Path.home()
    try:
        app_window.on_import_composites()
    except FileNotFoundError:
        # We don't care if there are files, just the last_dirs setting is important
        pass
    assert app_window.last_dirs['composite'] == directory
    # It's possible to open more than one at a time
    monkeypatch.setattr(QFileDialog, "getOpenFileNames", getOpenFileNames)
    app_window.on_import_composites()
    assert 'cmp' in app_window.ufo_scene.registry.registered_model_creators()
    assert 'cmp_2' in app_window.ufo_scene.registry.registered_model_creators()
    # When user picks a different directory it must be remembered
    assert app_window.last_dirs['composite'] == tests_directory
    # And used the next time
    self.message_shown = False
    app_window.on_import_composites()
    assert app_window.last_dirs['composite'] == tests_directory
    # Message about overwriting models must be shown
    assert self.message_shown
def test_on_export_composite(self, qtbot, monkeypatch, scene_with_composite, app_window):
    """Exporting adds exactly one '.cm' suffix and remembers the last used directory."""
    tests_directory = pkg_resources.resource_filename(__name__, 'composites')

    def getSaveFileNameDefault(inst, header, path, fltr):
        # The user accepts the suggested directory
        return (os.path.join(path, self.file_name), True)

    def getSaveFileName(inst, header, path, fltr):
        # The user navigates to the tests' composites directory
        return (os.path.join(tests_directory, self.file_name), True)

    def export_composite(inst, node, file_name):
        # Record the final file name instead of writing the file
        self.final_file_name = file_name

    # Nothing selected, must silently pass
    app_window.on_export_composite()
    # Make a composite node
    app_window.ufo_scene = scene_with_composite
    node = add_nodes_to_scene(app_window.ufo_scene, model_names=['cpm'])[0]
    node.graphics_object.setSelected(True)
    monkeypatch.setattr(ApplicationWindow, "export_composite", export_composite)
    # Default directory
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault)
    self.file_name = 'composite'
    directory = os.path.join(xdg_data_home(), 'tofu', 'flows', 'composites')
    app_window.on_export_composite()
    assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm')
    assert os.path.exists(directory)
    assert app_window.last_dirs['composite'] == directory
    # When user picks a different directory it must be remembered
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileName)
    app_window.on_export_composite()
    assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm')
    assert app_window.last_dirs['composite'] == tests_directory
    # And used the next time
    monkeypatch.setattr(QFileDialog, "getSaveFileName", getSaveFileNameDefault)
    app_window.on_export_composite()
    assert app_window.last_dirs['composite'] == tests_directory
    # .cm must not be added if it's present in the file name
    self.file_name = 'composite.cm'
    app_window.on_export_composite()
    assert self.final_file_name.endswith('.cm') and not self.final_file_name.endswith('.cm.cm')
def test_on_reset_view(self, qtbot, app_window):
    """Resetting the view undoes zooming, leaving a unit scale on both axes."""
    view = app_window.flow_view
    view.scale_up()
    app_window.on_reset_view()
    matrix = view.transform()
    assert matrix.m11() == pytest.approx(1)
    assert matrix.m22() == pytest.approx(1)
def test_on_property_links_action(self, qtbot, app_window):
    """The property links widget is visible after being shown."""
    links_widget = app_window.property_links_widget
    qtbot.addWidget(links_widget)
    links_widget.show()
    assert links_widget.isVisible()
def test_on_run(self, qtbot, monkeypatch, app_window):
    """Incomplete flows are refused; a complete one runs and disables the run action."""
    def fake_run(inst, graph):
        # Record the call instead of executing anything
        self.ran = True

    monkeypatch.setattr(UfoExecutor, "run", fake_run)
    scene = app_window.ufo_scene
    reader_0, reader_1, flat_corr, sink = add_nodes_to_scene(
        scene, model_names=['read', 'read', 'flat_field_correct', 'null'])

    # No connections -> many graphs, running must fail and the action stay enabled
    with pytest.raises(FlowError):
        app_window.on_run()
    assert app_window.run_action.isEnabled()

    scene.create_connection(reader_0['output'][0], flat_corr['input'][0])
    scene.create_connection(reader_1['output'][0], flat_corr['input'][1])
    scene.create_connection(flat_corr['output'][0], sink['input'][0])
    # The third flat_field_correct input is still unconnected
    with pytest.raises(FlowError):
        app_window.on_run()
    assert app_window.run_action.isEnabled()

    # Connect the last input, now the flow is complete and must run
    reader_2 = add_nodes_to_scene(scene, model_names=['read'])[0]
    scene.create_connection(reader_2['output'][0], flat_corr['input'][2])
    self.ran = False
    app_window.on_run()
    assert self.ran
    assert not app_window.run_action.isEnabled()
def test_on_execution_finished(self, qtbot, app_window):
    """Finishing execution resets the progress bar value to -1 and re-enables running."""
    bar = app_window.progress_bar
    app_window.run_action.setEnabled(False)
    bar.setMaximum(100)
    bar.setValue(50)
    app_window.on_execution_finished()
    assert bar.value() == -1
    assert app_window.run_action.isEnabled()
def test_global_exception_handler(qtbot):
    """GlobalExceptionHandler.excepthook emits the ``exception_occured`` signal."""
    handler = GlobalExceptionHandler()

    def slot(text):
        handler.called_signal = True

    handler.exception_occured.connect(slot)
    handler.called_signal = False
    try:
        raise FlowError('foo')
    except FlowError:
        # Catch only the exception raised above (a bare except would also swallow
        # KeyboardInterrupt etc.). Call the hook explicitly because assigning to
        # sys.exc_info does not seem to have any effect.
        handler.excepthook(*sys.exc_info())
    assert handler.called_signal
def test_get_filled_registry():
    """The filled registry provides a creator for the 'read' model."""
    creators = get_filled_registry().registered_model_creators()
    assert 'read' in creators
tofu-0.12.0/tofu/tests/test_flow_models.py 0000664 0000000 0000000 00000164325 14237137211 0020625 0 ustar 00root root 0000000 0000000 import pytest
import numpy as np
from PyQt5.QtCore import Qt
from PyQt5.QtGui import QValidator
from PyQt5.QtWidgets import QFileDialog, QInputDialog, QLineEdit
from tofu.flow.main import get_filled_registry
from tofu.flow.models import (CheckBoxViewItem, ComboBoxViewItem, get_composite_model_class,
get_composite_model_classes, get_composite_model_classes_from_json,
get_ufo_model_class, get_ufo_model_classes, ImageViewerModel,
IntQLineEditViewItem, MultiPropertyView, NumberQLineEditViewItem,
PropertyModel, PropertyView, QLineEditViewItem,
RangeQLineEditViewItem, UfoGeneralBackprojectModel, UfoIntValidator,
UfoMemoryOutModel, UfoModelError, UfoRangeValidator, UfoReadModel,
UfoRetrievePhaseModel, UfoModel, UfoTaskModel,
UfoVaryingInputModel, UfoWriteModel, ViewItem)
from tofu.flow.scene import UfoScene
from tofu.flow.util import CompositeConnection, MODEL_ROLE, PROPERTY_ROLE
from tofu.tests.flow_util import populate_link_model
def check_property_changed_emit(qtbot, view_item, expected, gui_func, gui_args, gui_kwargs=None,
                                show=False):
    """Check that *view_item* emits ``property_changed`` on user edits only.

    ``gui_func(*gui_args, **gui_kwargs)`` simulates the user interaction. Afterwards the item must
    hold *expected* and must have emitted ``property_changed``. Writing the previous value back
    through ``view_item.set`` must then NOT emit the signal.
    """
    kwargs = {} if gui_kwargs is None else gui_kwargs

    def mark_changed(item):
        item.change_called = True

    view_item.change_called = False
    view_item.property_changed.connect(mark_changed)
    widget = view_item.widget
    qtbot.addWidget(widget)
    if show:
        # Without show() the mouse click on a QCheckBox does not happen (pytest-qt bug?)
        widget.show()

    # Remember the current value to test a programmatic change afterwards
    previous = view_item.get()

    gui_func(*gui_args, **kwargs)
    assert view_item.get() == expected
    assert view_item.change_called

    view_item.change_called = False
    view_item.set(previous)
    # Programmatic access must not count as a user change
    assert not view_item.change_called
def make_properties():
    """Return a new mapping of property name -> [view item, visible flag], one per item type."""
    props = {}
    props['int'] = [IntQLineEditViewItem(0, 100, default_value=10), True]
    props['float'] = [NumberQLineEditViewItem(0, 100, default_value=0), True]
    props['string'] = [QLineEditViewItem(default_value='foo'), True]
    props['range'] = [RangeQLineEditViewItem(default_value=[1, 2, 3], num_items=3, is_float=True),
                      True]
    props['choices'] = [ComboBoxViewItem(['a', 'b', 'c']), True]
    props['check'] = [CheckBoxViewItem(checked=True), True]
    return props
class DummyPropertyModel(PropertyModel):
    """PropertyModel whose properties are provided by the module-level make_properties()."""

    def make_properties(self):
        properties = make_properties()
        return properties
@pytest.fixture(scope='function')
def property_view():
    """A non-scrollable PropertyView holding the properties from make_properties()."""
    props = make_properties()
    return PropertyView(properties=props, scrollable=False)
@pytest.fixture(scope='function')
def multi_property_view(nodes):
    """A MultiPropertyView with the 'cpm' group visible and the 'Read' group hidden."""
    cpm = nodes['cpm'].model
    reader = nodes['read'].model
    return MultiPropertyView(groups={cpm: True, reader: False})
def make_composite_model_class(nodes, name='foobar'):
    """Build a composite model class *name* out of the 'cpm' and 'average' nodes.

    The output of Pad (inside cpm) is connected to the first input of Average.
    """
    cpm = nodes['cpm'].model
    # Pad is nested in the cpm composite, so the connection has to use the cpm port through which
    # Pad's output is visible from the outside
    pad_index = cpm.get_outside_port('Pad', 'output', 0)[1]
    connections = [CompositeConnection('cpm', pad_index, 'Average', 0)]
    state = [
        ('cpm', cpm.save(), True, None),
        ('average', nodes['average'].model.save(), True, None),
    ]
    return get_composite_model_class(name, state, connections)
def create_scene(qtbot, registry):
    """Create a UfoScene for *registry* and register all of its views with qtbot."""
    scene = UfoScene(registry=registry)
    # Iterating an empty list is a no-op, so no emptiness check is needed
    for view in scene.views():
        qtbot.addWidget(view)
    return scene
def make_composite_node_in_scene(qtbot, nodes):
    """Return ``(scene, node)`` where *node* holds a composite from make_composite_model_class()."""
    composite_cls = make_composite_model_class(nodes)
    registry = get_filled_registry()
    # Register both composites so that we can create them
    for cls in (nodes['cpm'].model.__class__, composite_cls):
        registry.register_model(cls, category='Composite', registry=registry)
    scene = create_scene(qtbot, registry)
    return scene, scene.create_node(composite_cls)
@pytest.fixture(scope='function')
def composite_model(nodes):
    """An instance of the 'foobar' composite, which contains 'cpm' and 'Average'."""
    # Use the registry of 'cpm' so that the nested composite is known to it
    registry = nodes['cpm'].model._registry
    cls = make_composite_model_class(nodes)
    registry.register_model(cls, category='Composite', registry=registry)
    return cls(registry=registry)
@pytest.fixture(scope='function')
def general_backproject(qtbot):
    """A UfoGeneralBackprojectModel whose embedded widget is registered with qtbot."""
    backproject = UfoGeneralBackprojectModel()
    widget = backproject.embedded_widget()
    qtbot.addWidget(widget)
    return backproject
@pytest.fixture(scope='function')
def read_model(qtbot):
    """A UfoReadModel whose embedded widget is registered with qtbot."""
    reader = UfoReadModel()
    widget = reader.embedded_widget()
    qtbot.addWidget(widget)
    return reader
@pytest.fixture(scope='function')
def write_model(qtbot):
    """A UfoWriteModel whose embedded widget is registered with qtbot."""
    writer = UfoWriteModel()
    widget = writer.embedded_widget()
    qtbot.addWidget(widget)
    return writer
@pytest.fixture(scope='function')
def memory_out_model(qtbot):
    """A UfoMemoryOutModel with width and height set to 100, its widget registered with qtbot."""
    memory_out = UfoMemoryOutModel()
    for key in ('width', 'height'):
        memory_out[key] = 100
    qtbot.addWidget(memory_out.embedded_widget())
    return memory_out
@pytest.fixture(scope='function')
def image_viewer_model(qtbot):
    """An ImageViewerModel whose embedded widget is registered with qtbot."""
    viewer = ImageViewerModel()
    widget = viewer.embedded_widget()
    qtbot.addWidget(widget)
    return viewer
def test_ufo_int_validator():
    """UfoIntValidator accepts in-range ints, treats out-of-range ints as intermediate and rejects
    non-integers."""
    def expect(validator, state, texts):
        for text in texts:
            assert validator.validate(text, -1)[0] == state

    bounded = UfoIntValidator(-10, 10)
    expect(bounded, QValidator.Acceptable, ('0', '1', '-1'))
    expect(bounded, QValidator.Intermediate, ('101', '-101', '-'))
    expect(bounded, QValidator.Invalid, ('1.', '1.0', 'asdf'))
    # A value below a positive minimum is intermediate rather than invalid
    expect(UfoIntValidator(3, 10), QValidator.Intermediate, ('1',))
def test_ufo_range_validator():
    """UfoRangeValidator validates comma separated triples of integers or floats."""
    def expect(validator, state, texts):
        for text in texts:
            assert validator.validate(text, len(text))[0] == state

    # Integer
    int_range = UfoRangeValidator(num_items=3, is_float=False)
    expect(int_range, QValidator.Intermediate, (',,', ' ,,', '1,1,', ',1,'))
    expect(int_range, QValidator.Acceptable, ('1,-2,3',))
    expect(int_range, QValidator.Invalid, ('1,1.0,1', '-1,s,-1', '1,1,1,1', '1,1,1,'))

    # Float, including partially typed exponents
    float_range = UfoRangeValidator(num_items=3, is_float=True)
    expect(float_range, QValidator.Intermediate, (
        ',,', ' ,,', '.,,', '.e,,', '.e-,,', '.e+,,',
        '1.0e,,', '1.0e+,,', '1.0e-,,',
        '1e,,', '1e+,,', '1e-,,',
        '.1e,,', '.1e+,,', '.1e-,,',
    ))
    expect(float_range, QValidator.Acceptable, (
        '1,1,1', '-1,1,1', '1.,1.,1', '-1.,1.,1',
        '1.0e1,1.0,1', '1.0e+1,1.0,1', '1.0e-1,1.0,1',
        '.1,1.0,1', '.1e-1,1.0,1', '.1e+1,1.0,1', '.1e1,1.0,1',
    ))
    expect(float_range, QValidator.Invalid, (
        'e,,', 'e.,,', '+e,,', '-e,,', '+e.,,', '-e.,,',
        '1+,,', '1-,,', 'gfd,1,3',
    ))
def test_view_item_init(qtbot):
    """ViewItem applies the default value and tooltip and reports user edits.

    The test needs get/set implementations backed by a QLineEdit. They are provided by a local
    subclass instead of assigning ``ViewItem.get``/``ViewItem.set``. Assigning them would alter
    ViewItem for every test that runs afterwards, because the original methods were never
    restored.
    """
    class LineEditViewItem(ViewItem):
        def get(self):
            return self.widget.text()

        def set(self, value):
            self.widget.setText(value)

    edit = QLineEdit()
    qtbot.addWidget(edit)
    vit = LineEditViewItem(edit, default_value='foo', tooltip='tooltip')
    edit.textEdited.connect(vit.on_changed)
    assert vit.widget.toolTip() == 'tooltip'
    assert vit.widget.text() == 'foo'
    check_property_changed_emit(qtbot, vit, 'fooa', qtbot.keyClick, (edit, 'a'))
def test_check_box_view_item(qtbot):
    """CheckBoxViewItem reflects its checked state and toggles on a mouse click."""
    assert CheckBoxViewItem(checked=True).get()
    box = CheckBoxViewItem(checked=False, tooltip='tooltip')
    assert box.widget.toolTip() == 'tooltip'
    assert not box.get()
    # The widget has to be shown, otherwise the click does not reach the check box
    check_property_changed_emit(qtbot, box, True, qtbot.mouseClick,
                                (box.widget, Qt.LeftButton), show=True)
def test_combo_box_view_item(qtbot):
    """ComboBoxViewItem starts at its default entry and follows keyboard selection."""
    choices = ['a', 'b', 'c']
    combo = ComboBoxViewItem(choices, default_value='b', tooltip='tooltip')
    assert combo.widget.toolTip() == 'tooltip'
    assert combo.get() == 'b'
    # Typing 'c' selects the matching entry
    check_property_changed_emit(qtbot, combo, 'c', qtbot.keyClick, (combo.widget, 'c'))
def test_qline_edit_view_item(qtbot):
    """QLineEditViewItem starts with its default text and appends typed characters."""
    line_item = QLineEditViewItem(default_value='foo', tooltip='tooltip')
    assert line_item.widget.toolTip() == 'tooltip'
    assert line_item.get() == 'foo'
    check_property_changed_emit(qtbot, line_item, 'fooc', qtbot.keyClick,
                                (line_item.widget, 'c'))
def test_number_qline_edit_view_item(qtbot):
    """NumberQLineEditViewItem rejects out-of-range defaults and parses typed floats."""
    for bad_default in (1000, -1000):
        with pytest.raises(ValueError):
            NumberQLineEditViewItem(-100, 100, default_value=bad_default)
    number = NumberQLineEditViewItem(-100., 100., default_value=0., tooltip='tooltip')
    # The tooltip is checked with startswith, so it may carry additional text
    assert number.widget.toolTip().startswith('tooltip')
    assert number.get() == 0
    # The text is "0.0"; typing "1" makes it "0.01"
    check_property_changed_emit(qtbot, number, 0.01, qtbot.keyClick, (number.widget, '1'))
def test_int_qline_edit_view_item(qtbot):
    """IntQLineEditViewItem rejects out-of-range defaults and parses typed integers."""
    for bad_default in (1000, -1000):
        with pytest.raises(ValueError):
            IntQLineEditViewItem(-100, 100, default_value=bad_default)
    integer = IntQLineEditViewItem(-100, 100, default_value=0, tooltip='tooltip')
    # The tooltip is checked with startswith, so it may carry additional text
    assert integer.widget.toolTip().startswith('tooltip')
    assert integer.get() == 0
    # The text is "0"; typing "1" makes it "01", i.e. 1
    check_property_changed_emit(qtbot, integer, 1, qtbot.keyClick, (integer.widget, '1'))
def test_range_edit_view_item(qtbot):
    """RangeQLineEditViewItem holds a list of numbers; typed text extends the last one."""
    start = [1.0, 2.0, 3.0]
    range_item = RangeQLineEditViewItem(default_value=start, tooltip='tooltip')
    assert range_item.widget.toolTip().startswith('tooltip')
    assert range_item.get() == [1.0, 2.0, 3.0]
    # The last entry is 3.0; typing "1" makes it 3.01
    check_property_changed_emit(qtbot, range_item, [1.0, 2.0, 3.01], qtbot.keyClick,
                                (range_item.widget, '1'))
class TestPropertyView:
    """Tests for PropertyView filled by make_properties()."""

    def test_init(self, qtbot, property_view):
        assert len(property_view.property_names) != 0
        # Constructing with all defaults must work too
        PropertyView()

    def test_get_property(self, qtbot, property_view):
        # make_properties() gives 'int' the default 10
        assert property_view.get_property('int') == 10

    def test_set_property(self, qtbot, property_view):
        property_view.set_property('int', 50)
        assert property_view.get_property('int') == 50

    def test_on_property_changed(self, qtbot, property_view):
        line_edit = property_view._properties['int'].view_item.widget
        qtbot.addWidget(line_edit)
        # Appending "0" to the default 10 yields 100
        qtbot.keyClick(line_edit, '0')
        assert property_view.get_property('int') == 100

    def test_is_property_visible(self, qtbot, property_view):
        assert property_view.is_property_visible('int')

    def test_set_property_visible(self, qtbot, property_view):
        flipped = not property_view.is_property_visible('int')
        property_view.set_property_visible('int', flipped)
        assert property_view.is_property_visible('int') == flipped

    def test_restore_properties(self, qtbot, property_view):
        exported = property_view.export_properties()
        original_int = exported['int'][0]
        property_view.set_property('int', original_int + 1)
        property_view.restore_properties(exported)
        assert property_view.get_property('int') == original_int

    def test_export_properties(self, qtbot, property_view):
        exported = property_view.export_properties()
        assert 'int' in exported
        value, visible = exported['int'][0], exported['int'][1]
        assert value == property_view.get_property('int')
        assert visible == property_view.is_property_visible('int')
class TestMultiPropertyView:
    """Tests for a MultiPropertyView with a visible 'cpm' and a hidden 'Read' group."""

    def test_init(self, qtbot, multi_property_view):
        assert len(list(iter(multi_property_view))) == 2

    def test_getitem(self, qtbot, multi_property_view, nodes):
        assert multi_property_view['cpm'] == nodes['cpm'].model

    def test_contains(self, qtbot, multi_property_view):
        assert 'cpm' in multi_property_view
        assert 'foo' not in multi_property_view

    def test_iter(self, qtbot, multi_property_view):
        assert set(list(iter(multi_property_view))) == set(['cpm', 'Read'])

    def test_export_groups(self, qtbot, multi_property_view):
        # Set the visibility *before* exporting so that the exported state has to reflect it
        multi_property_view.set_group_visible('Read', False)
        state = multi_property_view.export_groups()
        assert state['Read']['model']['caption'] == 'Read'
        assert not state['Read']['visible']

    def test_restore_groups(self, qtbot, multi_property_view, nodes):
        multi_property_view['Read']['number'] = 100
        state = multi_property_view.export_groups()
        multi_property_view['Read']['number'] = 1000
        multi_property_view.restore_groups(state)
        # Compare with the exported value; a truthiness check would also accept 1000
        assert multi_property_view['Read']['number'] == 100

    def test_set_group_visible(self, qtbot, multi_property_view):
        visible = not multi_property_view.is_group_visible('cpm')
        multi_property_view.set_group_visible('cpm', visible)
        assert multi_property_view.is_group_visible('cpm') == visible

    def test_is_group_visible(self, qtbot, multi_property_view):
        assert multi_property_view.is_group_visible('cpm')
        assert not multi_property_view.is_group_visible('Read')
class TestUfoModel:
    """Tests for the caption handling of UfoModel."""

    def test_init(self):
        model = UfoModel()
        assert model.caption == model.base_caption

    def test_restore(self):
        model = UfoModel()
        state = {'caption': 'foo'}
        old_caption = model.caption
        model.restore(state, restore_caption=False)
        assert model.caption == old_caption
        model.restore(state, restore_caption=True)
        assert model.caption == 'foo'
        # 'caption' not in state, the old one must be preserved
        model = UfoModel()
        old_caption = model.caption
        model.restore({}, restore_caption=True)
        assert model.caption == old_caption

    def test_save(self):
        # The ``test_`` prefix is required for pytest to collect this method
        model = UfoModel()
        model.caption = 'foo'
        assert model.save()['caption'] == 'foo'
class TestPropertyModel:
    """Tests for PropertyModel, mostly via DummyPropertyModel and its make_properties()."""

    def test_init(self, qtbot):
        # The bare base class must be constructible
        PropertyModel()
        dummy = DummyPropertyModel()
        # make_properties must be called
        assert set(dummy.properties) == set(make_properties().keys())

    def test_getitem(self, qtbot):
        dummy = DummyPropertyModel()
        _ = dummy['int']
        with pytest.raises(KeyError):
            dummy['foo']

    def test_setitem(self, qtbot):
        dummy = DummyPropertyModel()
        dummy['int'] = 132
        assert dummy['int'] == 132

    def test_contains(self, qtbot):
        dummy = DummyPropertyModel()
        assert 'int' in dummy
        assert 'foo' not in dummy

    def test_iter(self, qtbot):
        dummy = DummyPropertyModel()
        assert set(dummy) == set(make_properties().keys())

    def test_on_property_changed(self, qtbot):
        def record(model, name, value):
            self.called_name = name
            self.called_value = value

        dummy = DummyPropertyModel()
        dummy.property_changed.connect(record)
        int_widget = dummy._view._properties['int'].view_item.widget
        qtbot.addWidget(int_widget)
        # Typing must emit property_changed with the property name and its new value
        qtbot.keyClick(int_widget, '0')
        assert self.called_value == dummy['int']
        assert self.called_name == 'int'

    def test_make_properties(self, qtbot):
        created = DummyPropertyModel().make_properties()
        assert created.keys() == make_properties().keys()
        assert PropertyModel().make_properties() == {}

    def test_copy_properties(self, qtbot):
        dummy = DummyPropertyModel()
        dummy['int'] = 123
        visible = not dummy._view.is_property_visible('int')
        dummy._view.set_property_visible('int', visible)
        copied = dummy.copy_properties()
        # The copy must be deep: later changes of the model must not leak into it
        dummy['int'] = 12
        dummy._view.set_property_visible('int', not visible)
        assert copied['int'][0].get() == 123
        assert copied['int'][1] == visible

    def test_embedded_widget(self, qtbot):
        assert PropertyModel().embedded_widget() is None
        assert isinstance(DummyPropertyModel().embedded_widget(), PropertyView)

    def test_restore(self, qtbot):
        dummy = DummyPropertyModel()
        snapshot = dummy.save()
        initial_value = dummy['int']
        initial_caption = dummy.caption
        toggled = not dummy._view.is_property_visible('int')
        dummy['int'] = initial_value + 1
        dummy._view.set_property_visible('int', toggled)
        dummy.caption = 'Foo'
        # Properties and visibility are restored, the caption only on request
        dummy.restore(snapshot, restore_caption=False)
        assert dummy['int'] == initial_value
        assert dummy._view.is_property_visible('int') == (not toggled)
        assert dummy.caption == 'Foo'
        dummy.restore(snapshot, restore_caption=True)
        assert dummy.caption == initial_caption

    def test_save(self, qtbot):
        dummy = DummyPropertyModel()
        initial_value = dummy['int']
        toggled = not dummy._view.is_property_visible('int')
        dummy['int'] = initial_value + 1
        dummy._view.set_property_visible('int', toggled)
        dummy.caption = 'Foo'
        saved = dummy.save()
        int_state = saved['properties']['int']
        assert int_state[0] == initial_value + 1
        assert int_state[1] == toggled
        assert saved['caption'] == 'Foo'
class TestUfoTaskModel:
    """Tests for UfoTaskModel, mainly with the flat-field-correct task."""

    def test_init(self, qtbot):
        ffc = UfoTaskModel('flat-field-correct')
        assert ffc.properties
        # A task doesn't need any special treatment by default
        for flag in ('expects_multiple_inputs', 'can_split_gpu_work', 'needs_fixed_scheduler'):
            assert not getattr(ffc, flag)

    def test_make_properties(self, qtbot):
        ffc = UfoTaskModel('flat-field-correct')
        # Config takes effect: 'dark-scale' must be hidden
        assert not ffc._view.is_property_visible('dark-scale')

    def test_create_ufo_task(self, qtbot):
        ffc = UfoTaskModel('flat-field-correct')
        ffc['dark-scale'] = 12.3
        # The task exposes the 'dark-scale' property as props.dark_scale
        assert ffc.create_ufo_task().props.dark_scale == pytest.approx(12.3)

    def test_uses_gpu(self, qtbot):
        assert UfoTaskModel('flat-field-correct').uses_gpu
        assert not UfoTaskModel('read').uses_gpu
def test_get_ufo_model_class(qtbot):
    """get_ufo_model_class builds a model class mirroring the ports of a UFO task."""
    # Flat-field correction, with its three inputs, is a fairly complicated task to test
    model_cls = get_ufo_model_class('flat-field-correct')
    # Class attribute
    assert model_cls.name == 'flat_field_correct'
    model = model_cls()
    # Instance attributes
    assert model.num_ports['input'] == 3
    assert model.num_ports['output'] == 1
    for index, caption in enumerate(['radios', 'darks', 'flats']):
        assert model.port_caption['input'][index] == caption
    assert model.port_caption['output'][0] == ''
class TestBaseCompositeModel:
def test_init(self, qtbot, monkeypatch, composite_model, scene):
    """Composite port counts, port captions and the view are derived from the inner models."""
    # cpm has 1 input and 2 outputs (read and pad are not connected) and average has 1 input and
    # 1 output, but cpm is connected with average, which reduces both port types by 1
    assert composite_model.num_ports['input'] == 1
    assert composite_model.num_ports['output'] == 2
    # Every outer port caption is "<inner model caption>:<inner port caption>", or just the
    # inner model caption when the inner port has an empty caption
    for port_type in ['input', 'output']:
        for i in range(composite_model.num_ports[port_type]):
            submodel, j = composite_model.get_model_and_port_index(port_type, i)
            subcaption = submodel.port_caption[port_type][j]
            if subcaption:
                subcaption = ':' + subcaption
            assert (composite_model.port_caption[port_type][i] == submodel.caption + subcaption)
    assert composite_model._view
    # num-inputs must take effect. Integer prompts are patched to answer 2 and text prompts
    # (presumably the composite name) to answer 'with-pr'
    monkeypatch.setattr(QInputDialog, "getInt", lambda *args, **kwargs: (2, True))
    monkeypatch.setattr(QInputDialog, "getText", lambda *args, **kwargs: ('with-pr', True))
    node = scene.create_node(scene.registry.create('retrieve_phase'))
    node.graphics_object.setSelected(True)
    # Wrap the selected node into a new composite
    node = scene.create_composite()
    assert node.model.get_model_from_path(['Retrieve Phase']).num_ports['input'] == 2
    # and it must not affect default registry creators
    kwargs = scene.registry.registered_model_creators()['retrieve_phase'][1]
    assert 'num_inputs' not in kwargs
def test_getitem(self, qtbot, composite_model, nodes):
    """Direct children are accessible by caption; unknown captions raise KeyError."""
    for caption in ('cpm', 'Average'):
        assert composite_model[caption]
    with pytest.raises(KeyError):
        composite_model['foo']
def test_contains(self, qtbot, composite_model, nodes):
    """Membership tests work with the captions of direct children."""
    assert 'cpm' in composite_model
    assert not ('foo' in composite_model)
def test_iter(self, qtbot, composite_model):
    """Iterating a composite yields the captions of its direct children."""
    captions = set(composite_model)
    assert captions == {'cpm', 'Average'}
def test_get_descendant_graph(self, qtbot, monkeypatch, composite_model, nodes):
    """get_descendant_graph has an edge from every composite to each of its children."""
    graph = composite_model.get_descendant_graph()
    cpm = composite_model['cpm']
    assert (composite_model, cpm) in graph.edges
    assert (composite_model, composite_model['Average']) in graph.edges
    assert (cpm, cpm['Read']) in graph.edges
    assert (cpm, cpm['Pad']) in graph.edges
    # in_subwindow=True raises ValueError when the composite is not edited in a window
    with pytest.raises(ValueError):
        composite_model.get_descendant_graph(in_subwindow=True)
    # Subwindow editing: the graph must contain the models of the editing window's nodes
    composite_model.edit_in_window()
    qtbot.addWidget(composite_model._other_view)
    graph = composite_model.get_descendant_graph(in_subwindow=True)
    cpm = composite_model._window_nodes['cpm'].model
    average = composite_model._window_nodes['Average'].model
    assert (composite_model, cpm) in graph.edges
    assert (composite_model, average) in graph.edges
    assert (cpm, cpm['Read']) in graph.edges
    assert (cpm, cpm['Pad']) in graph.edges
    composite_model._other_view.close()
    # Wrap foobar into an outer composite. When the outer composite is edited in a window but
    # foobar is not, get_descendant_graph(in_subwindow=True) must return the models of the outer
    # window together with foobar's internal models
    scene = create_scene(qtbot, composite_model._registry)
    inner = scene.create_node(composite_model.__class__)
    # The name prompt of the new composite answers 'outer'
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True))
    inner.graphics_object.setSelected(True)
    outer = scene.create_composite().model
    outer.edit_in_window()
    qtbot.addWidget(outer._other_view)
    graph = outer.get_descendant_graph(in_subwindow=True)
    inner = outer._window_nodes['foobar'].model
    assert (outer, inner) in graph.edges
    assert (inner, inner['Average']) in graph.edges
    assert (inner['cpm'], inner['cpm']['Read']) in graph.edges
    outer._other_view.close()
def test_contains_path(self, qtbot, composite_model, nodes):
    """contains_path follows caption paths through nested composites."""
    for path in (['Average'], ['cpm'], ['cpm', 'Read']):
        assert composite_model.contains_path(path)
    for path in (['cpm', 'Read 2'], ['foo']):
        assert not composite_model.contains_path(path)
def test_get_model_from_path(self, qtbot, composite_model, nodes):
    """Existing caption paths resolve to a model; unknown ones raise KeyError."""
    nested_read = composite_model.get_model_from_path(['cpm', 'Read'])
    assert nested_read
    with pytest.raises(KeyError):
        composite_model.get_model_from_path(['foo'])
def test_is_model_inside(self, qtbot, composite_model, nodes):
    """Direct and nested children are inside the composite; unrelated models are not."""
    for path in (['cpm'], ['cpm', 'Read']):
        assert composite_model.is_model_inside(composite_model.get_model_from_path(path))
    assert not composite_model.is_model_inside(nodes['read_2'].model)
def test_get_path_from_model(self, qtbot, composite_model, nodes):
    """get_path_from_model returns the chain of models from the composite down to the target."""
    cpm = composite_model['cpm']
    read = cpm['Read']
    assert composite_model.get_path_from_model(cpm) == [composite_model, cpm]
    # Nested model (the former second, identical cpm/Read check was redundant)
    assert composite_model.get_path_from_model(read) == [composite_model, cpm, read]
    # Models outside of the composite are unknown
    with pytest.raises(KeyError):
        composite_model.get_path_from_model(nodes['read_2'].model)
def test_leaf_paths(self, qtbot, composite_model, nodes):
    """Leaf paths lead from the composite to every leaf model, also through nested composites."""
    leaves = composite_model.get_leaf_paths(in_subwindow=False)
    cpm = composite_model['cpm']
    expected = [
        [composite_model, cpm, cpm['Read']],
        [composite_model, cpm, cpm['Pad']],
        [composite_model, composite_model['Average']],
    ]
    assert len(leaves) == len(expected)
    for path in expected:
        assert path in leaves
def test_set_property_links_model(self, qtbot, link_model, composite_model):
    """Setting the property links model propagates it to the children."""
    composite_model.property_links_model = link_model
    for model in (composite_model, composite_model['cpm']):
        assert model.property_links_model == link_model
def test_get_outside_port(self, qtbot, composite_model):
    """get_outside_port maps an inner model's port to the composite's outer port index.

    The composite has one input (cpm's Pad) and two outputs (cpm's Read and Average). The
    returned indices must agree with get_model_and_port_index, which maps outer input 0 to
    (cpm, pad index) and outer outputs 0 and 1 to cpm's Read and Average.
    """
    cpm = composite_model.get_model_from_path(['cpm'])
    # There is one input corresponding to cpm's pad model
    pad_index = cpm.get_outside_port('Pad', 'input', 0)[1]
    assert composite_model.get_outside_port('cpm', 'input', pad_index)[1] == 0
    # and two outputs: cpm's read and average, occupying distinct outer ports
    read_index = cpm.get_outside_port('Read', 'output', 0)[1]
    output_indices = {
        composite_model.get_outside_port('cpm', 'output', read_index)[1],
        composite_model.get_outside_port('Average', 'output', 0)[1],
    }
    assert output_indices == {0, 1}
def test_get_model_and_port_index(self, qtbot, composite_model):
    """Outer port indices resolve to the inner model and its own port index."""
    inner_model, inner_index = composite_model.get_model_and_port_index('input', 0)
    cpm = composite_model.get_model_from_path(['cpm'])
    # There is only one input, Pad inside cpm; its index reported by cpm must match
    pad_index = cpm.get_outside_port('Pad', 'input', 0)[1]
    assert inner_model == cpm
    assert inner_index == pad_index
    # Two outputs: one from cpm's Read model and one from Average
    average = composite_model.get_model_from_path(['Average'])
    # Take the Read index from the cpm inside composite_model and not from the cpm of the 'nodes'
    # fixture: they are different instances and, since ports are stored in dictionaries, the
    # Read output index may differ between them
    read_index = cpm.get_outside_port('Read', 'output', 0)[1]
    outputs = [composite_model.get_model_and_port_index('output', index) for index in (0, 1)]
    assert (cpm, read_index) in outputs
    assert (average, 0) in outputs
def test_embedded_widget(self, qtbot, composite_model):
    """A composite presents its children in a MultiPropertyView."""
    widget = composite_model.embedded_widget()
    assert isinstance(widget, MultiPropertyView)
def test_restore(self, qtbot, composite_model):
    """restore() brings back child properties and group visibility, the caption only on request."""
    state = composite_model.save()
    old_value = composite_model['cpm']['Pad']['width']
    old_caption = composite_model.caption
    visible = not composite_model._view.is_group_visible('cpm')
    # Change a nested property, a group's visibility and the caption
    composite_model['cpm']['Pad']['width'] = old_value + 1
    composite_model._view.set_group_visible('cpm', visible)
    composite_model.caption = 'Foo'
    composite_model.restore(state, restore_caption=False)
    assert composite_model['cpm']['Pad']['width'] == old_value
    assert composite_model._view.is_group_visible('cpm') == (not visible)
    assert composite_model.caption == 'Foo'
    composite_model.restore(state, restore_caption=True)
    assert composite_model.caption == old_caption
    # The internal connection must match the saved one
    conn = composite_model._connections[0]
    assert [[conn.from_unique_name, conn.from_port_index,
             conn.to_unique_name, conn.to_port_index]] == state['connections']
def test_restore_links(self, qtbot, nodes):
    """restore_links() recreates the saved property links without duplicating them."""
    def check_links(node, link_model):
        # One row with three linked properties
        assert link_model.rowCount() == 1
        assert link_model.columnCount() == 3
        assert link_model.find_items((node.model['cpm']['Read'], 'number'),
                                     (MODEL_ROLE, PROPERTY_ROLE))
        assert link_model.find_items((node.model['cpm']['Pad'], 'height'),
                                     (MODEL_ROLE, PROPERTY_ROLE))
        assert link_model.find_items((node.model['Average'], 'number'),
                                     (MODEL_ROLE, PROPERTY_ROLE))
        # A property which was never linked must not appear
        assert not link_model.find_items((node.model['cpm']['Read'], 'height'),
                                         (MODEL_ROLE, PROPERTY_ROLE))

    scene, node = make_composite_node_in_scene(qtbot, nodes)
    link_model = scene.property_links_model
    # Link three properties (presumably row 0, columns 0-2)
    link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0)
    link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1)
    link_model.add_item(node, node.model['Average'], 'number', 0, 2)
    # Make the links produced by save() the ones the model restores from
    node.model._links = node.model.save()['links']
    # Link model has to have the exact same entries as before
    link_model.clear()
    node.model.restore_links(node)
    check_links(node, link_model)
    # Second time doesn't add the same links twice
    node.model.restore_links(node)
    check_links(node, link_model)
def test_save(self, qtbot, nodes):
    """save() stores child state, group visibility, caption, connections and property links."""
    scene, node = make_composite_node_in_scene(qtbot, nodes)
    link_model = scene.property_links_model
    link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0)
    link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1)
    link_model.add_item(node, node.model['Average'], 'number', 0, 2)
    old_value = node.model['cpm']['Pad']['width']
    visible = not node.model._view.is_group_visible('cpm')
    node.model['cpm']['Pad']['width'] = old_value + 1
    node.model._view.set_group_visible('cpm', visible)
    node.model.caption = 'Foo'
    state = node.model.save()
    cpm_models_state = state['models']['cpm']['model']['models']
    assert state['models']['cpm']['visible'] == visible
    assert cpm_models_state['Pad']['model']['properties']['width'][0] == old_value + 1
    assert state['caption'] == 'Foo'
    cpm = node.model.get_model_from_path(['cpm'])
    pad_index = cpm.get_outside_port('Pad', 'output', 0)[1]
    assert state['connections'] == [['cpm', pad_index, 'Average', 0]]
    # Property links: take the link model's entries for the leaf models and drop the first
    # component of every path (presumably the composite itself) to compare with save()
    links = link_model.get_model_links([path[-1] for path in node.model.get_leaf_paths()])
    links = [[str_path[1:] for str_path in row] for row in links.values()]
    saved = node.model.save()['links']
    # One row
    assert len(saved) == len(links) == 1
    # All linked paths must be saved
    for str_path in saved[0]:
        assert str_path in links[0]
def test_on_connection_created(self, qtbot, composite_model):
    """Connections created inside the composite's edit window must be rejected."""
    composite_model.edit_in_window()
    qtbot.addWidget(composite_model._other_view)
    for window_node in composite_model._other_scene.nodes.values():
        if window_node.model.caption != 'cpm':
            continue
        cpm_model = window_node.model
        out_idx = cpm_model.get_outside_port('Read', 'output', 0)[1]
        in_idx = cpm_model.get_outside_port('Pad', 'input', 0)[1]
        source_port = window_node['output'][out_idx]
        target_port = window_node['input'][in_idx]
        before = len(composite_model._other_scene.connections)
        composite_model._other_scene.create_connection(source_port, target_port)
        # The scene must have refused the new connection
        assert len(composite_model._other_scene.connections) == before
    composite_model._other_view.close()
def test_on_connection_deleted(self, qtbot, composite_model):
    """Connections inside the composite's edit window must not be deletable."""
    composite_model.edit_in_window()
    qtbot.addWidget(composite_model._other_view)
    expected = len(composite_model._other_scene.connections)
    victim = composite_model._other_scene.connections[0]
    composite_model._other_scene.delete_connection(victim)
    # The deletion must have been refused
    assert len(composite_model._other_scene.connections) == expected
    composite_model._other_view.close()
def test_double_clicked(self, qtbot, composite_model):
    """Double clicking a composite opens it for editing in its own window."""
    composite_model.double_clicked(None)
    qtbot.addWidget(composite_model._other_view)
    assert composite_model.is_editing
    assert composite_model._other_view is not None
def test_on_other_scene_double_clicked(self, qtbot, composite_model):
    """Double clicking a nested composite in the edit window opens a window for it."""
    composite_model.double_clicked(None)
    qtbot.addWidget(composite_model._other_view)
    for node in composite_model._other_scene.nodes.values():
        if node.model.caption == 'cpm':
            node.model.double_clicked(composite_model._other_view)
            qtbot.addWidget(node.model._other_view)
            # The nested composite got its own window ...
            assert node.model._other_view is not None
            # ... and the outer composite is still being edited
            assert composite_model.is_editing and composite_model._other_view is not None
            break
    else:
        # Without this the test would pass without checking anything
        pytest.fail("Nested composite 'cpm' not found in the edit window")
def test_expand_into_graph(self, qtbot, composite_model):
    """The first edge of the expanded graph must mirror the first internal connection."""
    import networkx as nx
    graph = nx.MultiDiGraph()
    composite_model.expand_into_graph(graph)
    source, target, port_data = next(iter(graph.edges.data()))
    expected = composite_model._connections[0]
    actual = [source.caption, port_data['output'], target.caption, port_data['input']]
    assert actual == [
        expected.from_unique_name,
        expected.from_port_index,
        expected.to_unique_name,
        expected.to_port_index,
    ]
def test_add_slave_links(self, qtbot, monkeypatch, nodes):
    """Models shown in (nested) edit windows must be silently linked to the root models.

    The composite ``foobar`` created by ``make_composite_node_in_scene`` is edited first,
    then its inner ``cpm``. Afterwards ``foobar`` is wrapped into another composite
    ``outermost`` and every level is opened for editing again.
    """
    def crosscheck(model, root_model, property_name, link_model):
        # ``(model, property_name)`` must be registered as silent with the root
        # ``(root_model, property_name)`` as its source, and the root must list it as a slave
        key = (model, property_name)
        root_key = (root_model, property_name)
        assert link_model._silent[key] == root_key
        assert key in link_model._slaves[root_key]
    scene, node = make_composite_node_in_scene(qtbot, nodes)
    link_model = scene.property_links_model
    link_model.add_item(node, node.model['cpm']['Read'], 'number', 0, 0)
    link_model.add_item(node, node.model['cpm']['Pad'], 'height', 0, 1)
    link_model.add_item(node, node.model['Average'], 'number', 0, 2)
    # Not being edited, so nothing may be registered
    node.model.add_slave_links()
    assert link_model._silent == {}
    node.model.edit_in_window()
    qtbot.addWidget(node.model._other_view)
    # Standard editing setup
    assert hasattr(node.model, '_other_scene')
    assert hasattr(node.model, '_other_view')
    assert not node.model._other_scene.allow_node_creation
    assert not node.model._other_scene.allow_node_deletion
    # foobar's subwindow: the registering model is cpm and its internal models must be linked
    crosscheck(node.model._window_nodes['cpm'].model['Read'],
               node.model['cpm']['Read'], 'number', link_model)
    crosscheck(node.model._window_nodes['cpm'].model['Pad'],
               node.model['cpm']['Pad'], 'height', link_model)
    crosscheck(node.model._window_nodes['Average'].model,
               node.model['Average'], 'number', link_model)
    # foobar's subwindow and cpm's subwindow: cpm and also its models in the subwindow must
    # be linked
    cpm = node.model._window_nodes['cpm'].model
    cpm.edit_in_window()
    qtbot.addWidget(cpm._other_view)
    assert cpm.window_parent == node.model
    crosscheck(cpm._window_nodes['Read'].model,
               node.model['cpm']['Read'], 'number', link_model)
    crosscheck(cpm._window_nodes['Pad'].model,
               node.model['cpm']['Pad'], 'height', link_model)
    # Add one more composite layer, outer->foobar->cpm->Model; both the registering model and
    # its window_parent must be linked
    node.model._other_view.close()
    # Closing foobar's window must also have closed the nested cpm window
    assert cpm._other_view is None
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outermost', True))
    # Wrap the selected foobar node into the new composite 'outermost'
    node.graphics_object.setSelected(True)
    outer = scene.create_composite()
    outer.model.edit_in_window()
    qtbot.addWidget(outer.model._other_view)
    node_sub = outer.model._window_nodes['foobar']
    node_sub.model.edit_in_window()
    qtbot.addWidget(node_sub.model._other_view)
    cpm = node_sub.model._window_nodes['cpm'].model
    cpm.edit_in_window()
    qtbot.addWidget(cpm._other_view)
    # Models in foobar's window must be linked to the models inside the outermost composite
    crosscheck(node_sub.model._window_nodes['cpm'].model['Read'],
               outer.model['foobar']['cpm']['Read'], 'number', link_model)
    crosscheck(node_sub.model._window_nodes['cpm'].model['Pad'],
               outer.model['foobar']['cpm']['Pad'], 'height', link_model)
    crosscheck(node_sub.model._window_nodes['Average'].model,
               outer.model['foobar']['Average'], 'number', link_model)
    # Models in cpm's window must be linked to the same root models
    crosscheck(cpm._window_nodes['Read'].model,
               outer.model['foobar']['cpm']['Read'], 'number', link_model)
    crosscheck(cpm._window_nodes['Pad'].model,
               outer.model['foobar']['cpm']['Pad'], 'height', link_model)
    # Close the windows from the innermost to the outermost
    cpm._other_view.close()
    node_sub.model._other_view.close()
    outer.model._other_view.close()
def test_edit_in_window(self, qtbot, nodes):
    """Editing a composite sets up a locked scene, silent links and window parents."""
    composite_model = nodes['cpm'].model
    link_model = composite_model.property_links_model
    populate_link_model(link_model, nodes)
    composite_model.edit_in_window()
    qtbot.addWidget(composite_model._other_view)
    # The edit window exists and its scene forbids adding or removing nodes
    for attribute in ('_other_scene', '_other_view'):
        assert hasattr(composite_model, attribute)
    assert not composite_model._other_scene.allow_node_creation
    assert not composite_model._other_scene.allow_node_deletion
    # The silent link must have been registered with cpm's root Read model as the source
    first_root_key = next(iter(link_model._slaves))
    assert first_root_key == (composite_model.get_model_from_path(['Read']), 'y')
    # Subcomposites shown in a window must know their parent composite
    _scene, outer_node = make_composite_node_in_scene(qtbot, nodes)
    outer_node.model.edit_in_window()
    qtbot.addWidget(outer_node.model._other_view)
    nested_model = outer_node.model._window_nodes['cpm'].model
    assert nested_model.window_parent == outer_node.model
    outer_node.model._other_view.close()
    composite_model._other_view.close()
def test_view_close_event(self, qtbot, nodes):
    """Closing the edit window writes changes back and removes the silent links."""
    composite_model = nodes['cpm'].model
    link_model = composite_model.property_links_model
    populate_link_model(link_model, nodes)
    composite_model.edit_in_window()
    qtbot.addWidget(composite_model._other_view)
    for window_node in composite_model._other_scene.nodes.values():
        window_model = window_node.model
        if window_model.caption != 'Read':
            # Pad
            window_model['width'] += 10
            continue
        y_widget = window_model.embedded_widget()._properties['y'].view_item.widget
        qtbot.addWidget(window_model.embedded_widget())
        qtbot.keyClicks(y_widget, '11')
    # Linked models must be updated immediately
    assert composite_model['Read']['y'] == 11
    assert nodes['read'].model['number'] == 11
    assert nodes['read_2'].model['height'] == 11
    composite_model._other_view.close()
    # The composite's own models only receive the window changes once it is closed
    assert composite_model['Pad']['width'] == 10
    # The only silent link must be gone, so both registries are empty
    assert link_model._slaves == {}
    assert link_model._silent == {}
def test_expand_into_scene(self, qtbot, monkeypatch):
    """Expanding a composite must restore its nodes, properties, connections and links.

    Builds ``foobar`` = {cpm = {Read, Pad}, Average, Retrieve Phase}, connects its Average
    output to an outside Null node, then expands ``foobar`` back into the scene and checks
    the result.
    """
    def get_int(*args, **kwargs):
        # Stand-in for QInputDialog.getInt; the answer is switched via self.get_int_return
        return self.get_int_return
    monkeypatch.setattr(QInputDialog, "getInt", get_int)
    nodes = {}
    registry = get_filled_registry()
    scene = create_scene(qtbot, registry)
    # Composite node cpm made of read and pad
    for name in ['read', 'pad']:
        model_cls = registry.create(name)
        node = scene.create_node(model_cls)
        node.graphics_object.setSelected(True)
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
    nodes['cpm'] = scene.create_composite()
    nodes['cpm'].graphics_object.setSelected(False)
    model_cls = registry.create('average')
    nodes['average'] = scene.create_node(model_cls)
    # Answer for the number-of-inputs dialog, presumably asked by Retrieve Phase (it is
    # expected to have 2 inputs below)
    self.get_int_return = (2, True)
    model_cls = registry.create('retrieve_phase')
    nodes['retrieve_phase'] = scene.create_node(model_cls)
    # Add null node to create an outside connection
    null_cls = registry.create('null')
    null_node = scene.create_node(null_cls)
    # Make a property link
    scene.property_links_model.add_item(nodes['cpm'], nodes['cpm'].model['Read'],
                                        'number', 0, 0)
    scene.property_links_model.add_item(nodes['cpm'], nodes['cpm'].model['Pad'], 'width', 0, 1)
    # Export composite and reload it so that it remembers the links (important for testing of
    # adding property link duplicates)
    cpm_cls_with_links = get_composite_model_classes_from_json(nodes['cpm'].model.save())[0]
    registry.register_model(cpm_cls_with_links, category='Composites', registry=registry)
    scene.remove_node(nodes['cpm'])
    nodes['cpm'] = scene.create_node(registry.create('cpm'))
    # The outer composite will contain cpm (read and pad), average and retrieve_phase;
    # pad (via cpm's outside port) and average are connected
    pad_index = nodes['cpm'].model.get_outside_port('Pad', 'output', 0)[1]
    scene.create_connection(nodes['cpm']['output'][pad_index], nodes['average']['input'][0])
    nodes['cpm'].graphics_object.setSelected(True)
    nodes['average'].graphics_object.setSelected(True)
    nodes['retrieve_phase'].graphics_object.setSelected(True)
    monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('foobar', True))
    scene.create_composite()
    composite_node = scene.selected_nodes()[0]
    composite_model = composite_node.model
    # Create outside connection from outer composite's average to null
    port_null = null_node['input'][0]
    # Get the average index dynamically because it might be mapped to a different output port
    # every time (reader in cpm makes another output)
    average_index = composite_model.get_outside_port('Average', 'output', 0)[1]
    port_composite = composite_node['output'][average_index]
    scene.create_connection(port_composite, port_null)
    # Change some property to see if it persists after expansion
    composite_model['cpm']['Read']['number'] = 123
    # Make sure the stored num-inputs takes effect during expansion: if QInputDialog.getInt
    # were invoked now, it would answer "cancelled" and the checks below would fail
    self.get_int_return = (None, False)
    composite_model.expand_into_scene(scene, composite_node)
    # Nodes must be there
    assert (set([node.model.caption for node in scene.nodes.values()])
            == set(['Null', 'Average', 'Retrieve Phase', 'cpm']))
    # num-inputs took effect
    for node in scene.nodes.values():
        if node.model.caption == 'Retrieve Phase':
            assert node.model.num_ports['input'] == 2
            break
    # Changed properties must be there
    for node in scene.nodes.values():
        if node.model.caption == 'cpm':
            assert node.model['Read']['number'] == 123
            break
    # Connections must be preserved
    for connection in scene.connections:
        if connection.get_node('output').model.caption == 'cpm':
            # Internal composite connection Pad -> Average must be there
            assert connection.get_node('input').model.caption == 'Average'
            cpm_index = connection.get_port_index('output')
            cpm_model = connection.get_node('output').model
            assert cpm_model.get_model_and_port_index('output', cpm_index)[0].caption == 'Pad'
        else:
            # Outside connection Average -> Null must be there
            assert connection.get_node('input').model.caption == 'Null'
    # Property links must be there (exactly one row)
    assert scene.property_links_model.rowCount() == 1
    # Original composite node must be gone
    assert composite_node not in scene.nodes.values()
def test_get_composite_model_class(qtbot, nodes):
    """Composite classes need a registry to instantiate and reject empty or missing names."""
    model_cls = make_composite_model_class(nodes)
    # Instantiating without a registry must fail
    with pytest.raises(AttributeError):
        model_cls()
    # A name must be provided
    for bad_name in ('', None):
        with pytest.raises(UfoModelError):
            make_composite_model_class(nodes, name=bad_name)
class TestUfoGeneralBackprojectModel:
    def test_init(self, general_backproject):
        assert general_backproject.num_ports['input'] == 1
        assert general_backproject.num_ports['output'] == 1
        assert general_backproject.needs_fixed_scheduler is True
        assert general_backproject.can_split_gpu_work is True

    def test_make_properties(self, general_backproject):
        props = general_backproject.make_properties()
        assert 'slice-memory-coeff' in props

    def test_split_gpu_work(self, general_backproject):
        """Work splits across GPUs for valid regions; inverted regions are rejected."""
        from gi.repository import Ufo
        gpus = Ufo.Resources().get_gpu_nodes()
        if not gpus:
            # Report a skip instead of silently passing without testing anything
            pytest.skip('No GPU nodes available')
        region_keys = ('x-region', 'y-region', 'region')

        def valid():
            # Fresh list each time in case the model keeps a reference to it
            return [-100., 100., 1.]

        for key in region_keys:
            general_backproject[key] = valid()
        # Normal operation
        assert general_backproject.split_gpu_work(gpus)
        # Wrong input: every region (stop below start) is validated, one at a time
        for key in region_keys:
            general_backproject[key] = [-100., -200., 1.]
            with pytest.raises(UfoModelError):
                general_backproject.split_gpu_work(gpus)
            general_backproject[key] = valid()

    def test_create_ufo_task(self, general_backproject):
        general_backproject['region'] = [-100., 100., 1.]
        # Without an explicit region the model's own region is used
        ufo_task = general_backproject.create_ufo_task(region=None)
        assert ufo_task.props.region == pytest.approx(general_backproject['region'])
        # An explicit region overrides it
        ufo_task = general_backproject.create_ufo_task(region=[-10., 10., 1.])
        assert ufo_task.props.region == pytest.approx([-10., 10., 1.])
class TestUfoReadModel:
    def test_init(self, read_model):
        """The reader has no inputs and one output."""
        ports = read_model.num_ports
        assert (ports['input'], ports['output']) == (0, 1)

    def test_double_clicked(self, qtbot, monkeypatch, read_model):
        """The file selected in the dialog opened by a double click becomes the path."""
        from tofu.flow.filedirdialog import FileDirDialog
        chosen = 'foobarbaz'
        monkeypatch.setattr(FileDirDialog, "exec_", lambda *args: 1)
        monkeypatch.setattr(FileDirDialog, "selectedFiles", lambda *args: [chosen])
        read_model.double_clicked(None)
        assert read_model['path'] == chosen
class TestUfoVaryingInputModel:
    def test_init(self, qtbot, monkeypatch):
        """Without num_inputs a dialog asks for it; otherwise the given count is used."""
        def get_int(*args, **kwargs):
            self.called = True
            return (1, True)

        # No number of inputs specified, dialog needs to pop up
        self.called = False
        monkeypatch.setattr(QInputDialog, 'getInt', get_int)
        asked = UfoVaryingInputModel('opencl', num_inputs=None)
        qtbot.addWidget(asked.embedded_widget())
        assert self.called
        assert asked.num_ports['input'] == 1
        # e.g. opencl task can have multiple inputs; every per-port list must match
        given = UfoVaryingInputModel('opencl', num_inputs=4)
        qtbot.addWidget(given.embedded_widget())
        assert given.num_ports['input'] == 4
        for per_port in (given.data_type, given.port_caption, given.port_caption_visible):
            assert len(per_port['input']) == 4

    def test_save(self, qtbot):
        """The number of inputs is stored under 'num-inputs'."""
        model = UfoVaryingInputModel('opencl', num_inputs=4)
        qtbot.addWidget(model.embedded_widget())
        saved = model.save()
        assert saved['num-inputs'] == 4
class TestUfoRetrievePhaseModel:
    def test_distance_input(self, qtbot):
        """The distance validator accepts exactly as many values as there are inputs."""
        model = UfoRetrievePhaseModel(num_inputs=4)
        qtbot.addWidget(model.embedded_widget())
        validator = model._view._properties['distance'].view_item.widget.validator()
        expectations = [
            ('1,2,3,4', QValidator.Acceptable),
            ('1,2,3', QValidator.Intermediate),
            ('1,2,3,4,5', QValidator.Invalid),
        ]
        for text, state in expectations:
            assert validator.validate(text, 0)[0] == state

    def test_multidistance_fixed_method(self, qtbot):
        """With several inputs the method is forced to ctf_multidistance and locked."""
        for num_inputs in (1, 2):
            model = UfoRetrievePhaseModel(num_inputs=num_inputs)
            qtbot.addWidget(model.embedded_widget())
            properties = model._view._properties
            enabled = num_inputs == 1
            assert properties['method'].view_item.widget.isEnabled() == enabled
            if not enabled:
                assert model['method'] == 'ctf_multidistance'
            for name in ('distance-x', 'distance-y'):
                assert properties[name].view_item.widget.isEnabled() == enabled
class TestUfoWriteModel:
    def test_init(self, write_model):
        """The writer has one input and no outputs."""
        ports = write_model.num_ports
        assert ports['input'] == 1
        assert ports['output'] == 0

    def test_double_clicked(self, monkeypatch, write_model):
        """The file name chosen in the save dialog becomes the 'filename' property."""
        target = 'foobarbaz'
        monkeypatch.setattr(QFileDialog, "getSaveFileName", lambda *args: (target, None))
        write_model.double_clicked(None)
        assert write_model['filename'] == target

    def test_expects_multiple_inputs(self, write_model):
        """A '{region}' placeholder in the file name means multiple inputs are expected."""
        for filename, expected in (('foo{region}bar', True), ('foobar', False)):
            write_model['filename'] = filename
            assert bool(write_model.expects_multiple_inputs) is expected

    def test_setup_ufo_task(self, write_model):
        """A region is required exactly when the file name contains '{region}'."""
        # (file name, region that must work, region that must fail, expected task file name)
        cases = (
            ('{region}', [0, 1, 1], None, '0'),
            ('foo.tif', None, [0, 1, 1], 'foo.tif'),
        )
        for filename, good_region, bad_region, expected in cases:
            write_model['filename'] = filename
            ufo_task = write_model.create_ufo_task(region=good_region)
            with pytest.raises(UfoModelError):
                write_model.create_ufo_task(region=bad_region)
            assert ufo_task.props.filename == expected
class TestUfoMemoryOutModel:
    def test_init(self, memory_out_model):
        assert memory_out_model.num_ports['input'] == 1
        assert memory_out_model.num_ports['output'] == 1

    def test_expects_multiple_inputs(self, memory_out_model):
        memory_out_model['number'] = '{region}'
        assert memory_out_model.expects_multiple_inputs
        memory_out_model['number'] = '1'
        assert not memory_out_model.expects_multiple_inputs

    def test_make_properties(self, memory_out_model):
        prop_names = {'width', 'height', 'depth', 'number'}
        assert prop_names == memory_out_model.make_properties().keys()

    def test_out_data(self, monkeypatch, memory_out_model):
        """Complete batches are emitted once, in region order, and freed afterwards."""
        # Single input
        def on_single(port_index):
            self.num_called += 1
            self.data = memory_out_model.out_data(port_index)

        self.num_called = 0
        memory_out_model['number'] = 10
        shape = (int(memory_out_model['number']),
                 memory_out_model['height'],
                 memory_out_model['width'])
        memory_out_model.create_ufo_task()
        batch = memory_out_model._batches[0]
        memory_out_model.data_updated.connect(on_single)
        batch.data[:] = 3
        assert len(memory_out_model._batches) == 1
        assert batch.data.shape == shape
        # One processing notification per slice of the 3D batch
        for _ in range(shape[0]):
            batch._on_processed(None)
        # Called once per 3D array
        assert self.num_called == 1
        # out_data has been set to the batch output
        np.testing.assert_almost_equal(self.data, 3)
        # Original data must have been freed
        assert memory_out_model._batches == [None]
        # Detach the first slot so it does not also run during the multiple-inputs part
        memory_out_model.data_updated.disconnect(on_single)
        memory_out_model.reset_batches()

        # Multiple inputs
        def on_multiple(port_index):
            # Append the first item in the current result
            self.called.append(memory_out_model.out_data(port_index)[0, 0, 0])

        self.called = []
        memory_out_model.data_updated.connect(on_multiple)
        memory_out_model['number'] = '{region}'
        # Two parallel batches of four regions each
        for j in range(2):
            for i in range(4):
                memory_out_model.create_ufo_task(region=[0, 10, 1])
                # Set batch data to its linearized index to make checking easy
                memory_out_model._batches[4 * j + i].data[:] = 4 * j + i
            # Out of order processing (builtin int: the np.int alias was removed in NumPy 1.24)
            for batch_id in np.array([2, 0, 1, 3], dtype=int) + (4 * j):
                for _ in range(10):
                    memory_out_model._batches[batch_id]._on_processed(None)
            # All regions in the current parallel batch must have been processed
            assert memory_out_model._waiting_list == []
        # Result must be in order
        np.testing.assert_almost_equal(self.called, np.arange(8))
        # Original data must have been freed
        assert memory_out_model._batches == [None] * 8

    def test_reset_batches(self, memory_out_model):
        memory_out_model.reset_batches()
        assert memory_out_model._batches == []
        assert memory_out_model._waiting_list == []
        assert memory_out_model._expecting_id == 0
        assert memory_out_model._current_data is None

    def test_setup_ufo_task(self, memory_out_model):
        memory_out_model['number'] = '{region}'
        # Must pass
        memory_out_model.create_ufo_task(region=[0, 100, 1])
        memory_out_model.create_ufo_task(region=[100, 200, 1])
        assert len(memory_out_model._batches) == 2
        # Must fail
        with pytest.raises(UfoModelError):
            memory_out_model.create_ufo_task(region=None)
        memory_out_model.reset_batches()
        memory_out_model['number'] = '100'
        # Must pass
        memory_out_model.create_ufo_task(region=None)
        # Must fail
        with pytest.raises(UfoModelError):
            memory_out_model.create_ufo_task(region=[0, 100, 1])
        assert len(memory_out_model._batches) == 1
class TestImageViewerModel:
    def test_init(self, image_viewer_model):
        """The viewer has one input and no outputs."""
        ports = image_viewer_model.num_ports
        assert ports['input'] == 1
        assert ports['output'] == 0

    def test_double_clicked(self, qtbot, image_viewer_model):
        """The viewer window (``_pg_window``) shows on double click only if there are images."""
        def pg_window():
            return image_viewer_model._widget._pg_window

        image_viewer_model.double_clicked(None)
        # No images, no pop up
        assert pg_window() is None
        image_viewer_model._widget.images = np.arange(1000).reshape(10, 10, 10)
        image_viewer_model.double_clicked(None)
        assert pg_window().isVisible()
        qtbot.addWidget(pg_window())
        # After the user closes the window, double clicking must show it again
        pg_window().close()
        image_viewer_model.double_clicked(None)
        assert pg_window().isVisible()

    def test_set_in_data(self, image_viewer_model):
        """Incoming images are appended until reset_batches() is called."""
        images = np.arange(1000).reshape(10, 10, 10)

        def shown_shape():
            return image_viewer_model._widget.images.shape

        image_viewer_model.set_in_data(images, None)
        assert shown_shape() == images.shape
        # A second batch is stacked along the first axis
        image_viewer_model.set_in_data(images, None)
        assert shown_shape() == (2 * images.shape[0],) + images.shape[1:]
        # After a reset, new images replace the old ones instead of being appended
        image_viewer_model.reset_batches()
        image_viewer_model.set_in_data(images, None)
        assert shown_shape() == images.shape

    def test_reset_batches(self, image_viewer_model):
        image_viewer_model.reset_batches()
        assert image_viewer_model._reset
def test_get_ufo_model_classes():
    """All classes are found, blacklisted ones are excluded, selection by name works."""
    classes = list(get_ufo_model_classes())
    # All
    assert classes
    # Blacklisted 'read' must not be among them
    names = [cls.name for cls in classes]
    assert 'read' not in names
    # Selection
    selected = list(get_ufo_model_classes(names=['pad']))
    assert len(selected) == 1
def test_get_composite_model_classes_from_json(qtbot, composite_model):
    """Classes come back bottom-up: the nested composite first, the outermost last."""
    state = composite_model.save()
    names = [model_cls.name for model_cls in get_composite_model_classes_from_json(state)]
    assert names == ['cpm', 'foobar']
def test_get_composite_model_classes():
    """Smoke test: loading composite model classes runs and yields a non-empty result."""
    result = get_composite_model_classes()
    assert result
tofu-0.12.0/tofu/tests/test_flow_propertylinksmodels.py 0000664 0000000 0000000 00000047263 14237137211 0023474 0 ustar 00root root 0000000 0000000 import pytest
from qtpy.QtCore import QByteArray, QMimeData, QModelIndex
from tofu.flow.propertylinksmodels import _get_string_path
from tofu.flow.propertylinkswidget import _encode_mime_data
from tofu.flow.util import MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE
from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model
def setup_silent(link_model, nodes):
    """Link read's 'number' and attach two silent followers to it.

    The followers are cpm's inner ``Read['number']`` and ``read_2['height']``. Both are reset
    to 0 so later checks cannot pass by coincidence.

    Returns:
        The ``(model, property_name)`` key of the linked source.
    """
    source_node = nodes['read']
    follower_node = nodes['read_2']
    cpm_node = nodes['cpm']
    root_key = (source_node.model, 'number')
    link_model.add_item(source_node, source_node.model, 'number', -1, -1)
    link_model.add_silent(cpm_node.model['Read'], 'number', *root_key)
    link_model.add_silent(follower_node.model, 'height', *root_key)
    # Zero the followers so that the link checks cannot succeed by luck
    cpm_node.model['Read']['number'] = 0
    follower_node.model['height'] = 0
    return root_key
class TestNodeTreeModel:
    def test_add_node(self, qtbot, node_model, nodes):
        """Unsupported models are ignored; composites bring their child models and properties."""
        # Unsupported model type is not added
        node_model.add_node(nodes['image_viewer'])
        assert node_model.rowCount() == 0
        # Supported plain model type
        node_model.add_node(nodes['read'])
        assert node_model.rowCount() == 1
        # Composite
        composite = nodes['cpm']
        node_model.add_node(composite)
        item = node_model.findItems('cpm')[0]
        # The item holds the composite node ...
        assert item.data(role=NODE_ROLE) == composite
        # ... its child models ...
        pad, read = composite.model['Pad'], composite.model['Read']
        assert item.child(0).data(role=MODEL_ROLE) == pad
        assert item.child(1).data(role=MODEL_ROLE) == read
        # ... and their properties
        assert item.child(0).child(0).text() == sorted(pad)[0]

    def test_remove_node(self, qtbot, node_model, nodes):
        composite = nodes['cpm']
        node_model.add_node(composite)
        assert node_model.rowCount() == 1
        node_model.remove_node(composite)
        assert node_model.rowCount() == 0

    def test_set_nodes(self, qtbot, node_model, nodes):
        """Nodes end up in the rows in the order in which they were given."""
        names = ['cpm', 'read']
        node_model.set_nodes([nodes[name] for name in names])
        for row, name in enumerate(names):
            assert node_model.item(row).data(role=NODE_ROLE) == nodes[name]

    def test_clear(self, qtbot, node_model, nodes):
        node_model.set_nodes(nodes.values())
        assert node_model.rowCount() > 0
        assert node_model.columnCount() > 0
        node_model.clear()
        assert (node_model.rowCount(), node_model.columnCount()) == (0, 0)
class TestPropertyLinksModel:
def test_add_item(self, qtbot, link_model, nodes):
    """Added items carry node, model and property, are unique and link their properties."""
    read = nodes['read']
    composite = nodes['cpm']
    composite.model.property_links_model = link_model
    # Zero the target first so the link check below cannot pass by coincidence
    composite.model['Read']['number'] = 0

    def expect_item(row, column, node, model):
        cell = link_model.item(row, column)
        assert cell.data(role=NODE_ROLE) == node
        assert cell.data(role=MODEL_ROLE) == model
        assert cell.data(role=PROPERTY_ROLE) == 'number'

    # Items must be added
    link_model.add_item(read, read.model, 'number', -1, -1)
    expect_item(0, 0, read, read.model)
    link_model.add_item(composite, composite.model['Read'], 'number', 0, -1)
    expect_item(0, 1, composite, composite.model['Read'])
    # The same item cannot be added twice
    with pytest.raises(ValueError):
        link_model.add_item(read, read.model, 'number', -1, -1)
    # Properties must be linked
    read.model['number'] = 100
    read.model.property_changed.emit(read.model, 'number', read.model['number'])
    assert composite.model['Read']['number'] == read.model['number']
    # Re-adding the composite while it is being edited must set up the slave links
    link_model.remove_item(link_model.find_items([composite], [NODE_ROLE])[0])
    composite.model.edit_in_window()
    qtbot.addWidget(composite.model._other_view)
    link_model.add_item(composite, composite.model['Read'], 'number', 0, -1)
    window_key = (composite.model._window_nodes['Read'].model, 'number')
    root_key = (composite.model['Read'], 'number')
    assert link_model._slaves[root_key] == [window_key]
    assert link_model._silent[window_key] == root_key
def test_remove_item(self, qtbot, link_model, nodes):
    """Removing the source item drops its silent links and disconnects the properties."""
    source = nodes['read'].model
    target = nodes['read_2'].model

    def push(value):
        source['number'] = value
        source.property_changed.emit(source, 'number', source['number'])

    link_model.add_item(nodes['read'], source, 'number', -1, -1)
    link_model.add_item(nodes['read_2'], target, 'number', 0, -1)
    link_model.add_silent(nodes['cpm'].model['Read'], 'number', source, 'number')
    # Properties must be connected at first
    push(100)
    assert target['number'] == source['number']
    link_model.remove_item(link_model.indexFromItem(link_model.item(0, 0)))
    assert link_model.item(0, 0) is None
    assert link_model._silent == {}
    assert link_model._slaves == {}
    # After removal, changes must not propagate: read_2 keeps the old 100
    push(0)
    assert target['number'] == 100
def test_contains(self, qtbot, link_model, nodes):
    """``in`` finds the text of an added item and rejects unknown text."""
    composite = nodes['cpm']
    link_model.add_item(composite, composite.model['Read'], 'number', 0, -1)
    present = link_model.item(0, 0).text()
    assert present in link_model
    assert 'foo' not in link_model
def test_clear(self, qtbot, link_model, nodes):
    """clear() empties the table as well as the silent-link bookkeeping."""
    read = nodes['read']
    for node, row in ((read, -1), (nodes['read_2'], 0)):
        link_model.add_item(node, node.model, 'number', row, -1)
    link_model.add_silent(nodes['cpm'].model['Read'], 'number', read.model, 'number')
    link_model.clear()
    assert (link_model.rowCount(), link_model.columnCount()) == (0, 0)
    assert link_model._silent == {}
    assert link_model._slaves == {}
def test_find_items(self, qtbot, link_model, nodes):
    """find_items matches every given value against its corresponding role."""
    inside = nodes['read'].model
    outside = nodes['read_2'].model
    both_roles = (MODEL_ROLE, PROPERTY_ROLE)
    # Empty model
    assert link_model.find_items([inside], [MODEL_ROLE]) == []
    link_model.add_item(nodes['read'], inside, 'number', -1, -1)
    # Not inside
    assert link_model.find_items([outside], [MODEL_ROLE]) == []
    # Inside
    found = link_model.find_items([inside], [MODEL_ROLE])
    assert found[0].data(role=MODEL_ROLE) == inside
    # Any (model, property) combination other than the added one must not match
    for model, prop in ((outside, 'height'), (inside, 'height'), (outside, 'number')):
        assert link_model.find_items((model, prop), both_roles) == []
    # Model inside, property inside
    item = link_model.find_items((inside, 'number'), both_roles)[0]
    assert item.data(role=MODEL_ROLE) == inside
    assert item.data(role=PROPERTY_ROLE) == 'number'
def test_get_model_links(self, qtbot, link_model, nodes):
    """get_model_links groups the string paths of the given models' links by row."""
    populate_link_model(link_model, nodes)
    # NOTE(review): a single model is passed here while the other calls pass a list of
    # models; confirm this is intended.
    assert link_model.get_model_links(nodes['read_3'].model) == {}
    links = link_model.get_model_links([nodes['read'].model,
                                        nodes['read_2'].model,
                                        nodes['cpm'].model['Read']])
    links = list(links.values())
    # Just one row
    assert len(links) == 1
    # Three items in that row
    assert len(links[0]) == 3
    assert [nodes['read'].model.caption, 'number'] in links[0]
    assert [nodes['read_2'].model.caption, 'height'] in links[0]
    # A nested model is addressed by the captions along its path plus the property name
    path = nodes['cpm'].model.get_path_from_model(nodes['cpm'].model['Read'])
    str_path = [model.caption for model in path] + ['y']
    assert str_path in links[0]
def test_get_root_model(self, qtbot, link_model, nodes):
    """The root is the model itself when linked directly, or the source of a silent link."""
    root = nodes['read'].model
    link_model.add_item(nodes['read'], root, 'number', -1, -1)
    # Not inside
    assert link_model.get_root_model(nodes['read_2'].model) is None
    # Directly inside
    assert link_model.get_root_model(root) == root
    # Indirectly inside via a silent link
    follower = nodes['cpm'].model['Read']
    link_model.add_silent(follower, 'number', root, 'number')
    assert link_model.get_root_model(follower) == root
def test_get_model_properties(self, qtbot, link_model, nodes):
    """All linked property names of a model are reported."""
    read = nodes['read']
    for prop in ('number', 'height'):
        link_model.add_item(read, read.model, prop, -1, -1)
    # Unlinked model
    assert link_model.get_model_properties(nodes['read_2'].model) == []
    # Multiple properties
    assert set(link_model.get_model_properties(read.model)) == {'number', 'height'}
def test_add_silent(self, qtbot, link_model, nodes):
    """Silent links need a linked source and forward the source's value to followers."""
    source = nodes['read'].model
    inner_read = nodes['cpm'].model['Read']
    follower = nodes['read_2'].model
    orig_key = setup_silent(link_model, nodes)
    # The source model is not in the link model
    with pytest.raises(ValueError):
        link_model.add_silent(inner_read, 'height', nodes['read_3'].model, 'number')
    # The source property is not in the link model
    with pytest.raises(ValueError):
        link_model.add_silent(inner_read, 'height', source, 'height')
    # Both followers are registered in both directions
    assert len(link_model._slaves[orig_key]) == 2
    for key in ((inner_read, 'number'), (follower, 'height')):
        assert link_model._silent[key] == orig_key
        assert key in link_model._slaves[orig_key]
    # Properties connected
    source['number'] = 100
    source.property_changed.emit(source, 'number', source['number'])
    assert inner_read['number'] == source['number']
    assert follower['height'] == source['number']
def test_remove_silent(self, qtbot, link_model, nodes):
    """Removed silent links stop following; the source key goes away with the last one."""
    source = nodes['read'].model
    inner_read = nodes['cpm'].model['Read']
    follower = nodes['read_2'].model
    orig_key = setup_silent(link_model, nodes)
    first = (inner_read, 'number')
    link_model.remove_silent(*first)
    assert first not in link_model._silent
    # The removed link is disconnected, the remaining follower still receives updates
    source['number'] = 100
    source.property_changed.emit(source, 'number', source['number'])
    assert inner_read['number'] == 0
    assert follower['height'] == source['number']
    # No more slaves, so the original key must be removed as well
    link_model.remove_silent(follower, 'height')
    assert orig_key not in link_model._slaves
def test_replace_item(self, qtbot, link_model, nodes):
    """replace_item swaps the model of an item and re-targets its silent links."""
    read, read_2, composite = nodes['read'], nodes['read_2'], nodes['cpm']
    orig_key = setup_silent(link_model, nodes)
    replacer = nodes['read_3']
    found = link_model.find_items(orig_key, (MODEL_ROLE, PROPERTY_ROLE))[0]
    position = (found.row(), found.column())
    link_model.replace_item(replacer, replacer.model, orig_key[0])
    assert link_model.item(*position).data(role=MODEL_ROLE) == replacer.model

    def change_number(model):
        model['number'] = 100
        model.property_changed.emit(model, 'number', model['number'])

    # The replaced model is no longer linked, so the silent models must stay untouched
    change_number(read.model)
    assert composite.model['Read']['number'] == 0
    assert read_2.model['height'] == 0
    # The replacer now drives the silent models
    change_number(replacer.model)
    assert composite.model['Read']['number'] == replacer.model['number']
    assert read_2.model['height'] == replacer.model['number']
def test_on_node_rows_about_to_be_removed(self, qtbot, link_model, node_model, nodes):
    """Removing nodes from the node model removes their link items too."""
    read, read_2, read_3 = (nodes[name] for name in ('read', 'read_2', 'read_3'))
    for node in (read, read_2, read_3):
        node_model.add_node(node)
    link_model.add_item(read, read.model, 'number', -1, -1)
    link_model.add_item(read_2, read_2.model, 'number', 0, -1)
    link_model.add_item(read_3, read_3.model, 'number', -1, -1)
    # Removing the first node row drops the item of `read`
    node_model.removeRow(0)
    assert link_model.find_items([read.model], [MODEL_ROLE]) == []
    # Clearing the node model drops the remaining items
    node_model.clear()
    for node in (read_2, read_3):
        assert link_model.find_items([node.model], [MODEL_ROLE]) == []
def test_canDropMimeData(self, qtbot, link_model, node_model, nodes):
    """canDropMimeData accepts only encoded property data of a compatible type."""
    read = nodes['read']
    read_2 = nodes['read_2']
    node_model.add_node(read)
    node_model.add_node(read_2)
    root = QModelIndex()

    def mime_for(row, prop_name):
        return _encode_mime_data(get_index_from_treemodel(node_model, row, prop_name))

    # A foreign MIME type is refused
    foreign = QMimeData()
    foreign.setData('application/x-foobar', QByteArray())
    assert not link_model.canDropMimeData(foreign, None, -1, -1, root)

    # Dropping on the root: allowed until the property is linked, refused afterwards
    data = mime_for(0, 'number')
    assert link_model.canDropMimeData(data, None, -1, -1, root)
    link_model.add_item(read, read.model, 'number', -1, -1)
    assert not link_model.canDropMimeData(data, None, -1, -1, root)

    # Dropping on an existing item depends on property type compatibility
    compatible = mime_for(1, 'number')
    parent = link_model.indexFromItem(link_model.item(0, 0))
    assert link_model.canDropMimeData(compatible, None, 0, 0, parent)
    incompatible = mime_for(1, 'path')
    parent = link_model.indexFromItem(link_model.item(0, 0))
    assert not link_model.canDropMimeData(incompatible, None, 0, 0, parent)
def test_dropMimeData(self, qtbot, link_model, node_model, nodes):
    """Dropped property data creates link items at the expected cells."""
    read = nodes['read']
    read_2 = nodes['read_2']
    node_model.add_node(read)
    node_model.add_node(read_2)

    def drop(row, parent):
        index = get_index_from_treemodel(node_model, row, 'number')
        link_model.dropMimeData(_encode_mime_data(index), None, -1, -1, parent)

    def check(item, node):
        assert item.data(role=NODE_ROLE) == node
        assert item.data(role=MODEL_ROLE) == node.model
        assert item.data(role=PROPERTY_ROLE) == 'number'

    # No parent: the item ends up at (0, 0)
    drop(0, QModelIndex())
    check(link_model.item(0, 0), read)
    # On a parent: the item ends up next to it at (0, 1)
    drop(1, link_model.indexFromItem(link_model.item(0, 0)))
    check(link_model.item(0, 1), read_2)
def test_save(self, qtbot, link_model, nodes):
    """Saved entries carry the node id and string path of each populated record."""
    records = populate_link_model(link_model, nodes)
    saved_row = link_model.save()[0]
    for position, (node_id, str_path) in enumerate(saved_row):
        record = records[position]
        assert node_id == record[0].id
        assert str_path == _get_string_path(record[0], record[1], record[2])
def test_restore(self, qtbot, link_model, nodes):
    """restore() rebuilds the saved content and discards whatever was there before."""
    records = populate_link_model(link_model, nodes)
    state = link_model.save()
    link_model.clear()
    # Put in an item which restore() has to get rid of
    read_3 = nodes['read_3']
    link_model.add_item(read_3, read_3.model, 'number', -1, -1)
    nodes_by_id = {node.id: node for node in nodes.values()}
    link_model.restore(state, nodes_by_id)
    assert link_model.columnCount() == 3
    for column in range(link_model.columnCount()):
        item = link_model.item(0, column)
        expected = records[column]
        for role, field in ((NODE_ROLE, 0), (MODEL_ROLE, 1), (PROPERTY_ROLE, 2)):
            assert item.data(role=role) == expected[field]
    # The pre-existing item must be gone
    assert link_model.find_items([read_3.model], [MODEL_ROLE]) == []
def test_compact(self, qtbot, link_model, nodes):
    """compact() shifts items left and drops rows/columns left empty."""
    read, read_2, read_3, read_4 = (nodes[name]
                                    for name in ('read', 'read_2', 'read_3', 'read_4'))

    def populate():
        # 2x2 grid: row 0 = read, read_2; row 1 = read_3, read_4
        layout = ((read, 0, 0), (read_2, 0, 1), (read_3, 1, 0), (read_4, 1, 1))
        for node, row, column in layout:
            link_model.add_item(node, node.model, 'number', row, column)

    def remove(row, column):
        link_model.remove_item(link_model.indexFromItem(link_model.item(row, column)))

    def check(row_count, column_count):
        assert link_model.rowCount() == row_count
        assert link_model.columnCount() == column_count

    # Removing the last item of a row keeps the table size
    populate()
    remove(0, 1)
    link_model.compact()
    check(2, 2)
    link_model.clear()

    # Shift item to the left to an unused cell
    populate()
    remove(0, 0)
    link_model.compact()
    assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2
    check(2, 2)
    # Nothing in the row, remove it
    remove(0, 0)
    link_model.compact()
    check(1, 2)

    # Remove column 0 and shift the 1st column to the left
    link_model.clear()
    populate()
    remove(0, 0)
    remove(1, 0)
    link_model.compact()
    assert link_model.item(0, 0).data(role=NODE_ROLE) == read_2
    assert link_model.item(1, 0).data(role=NODE_ROLE) == read_4
    check(2, 1)
def test_on_property_changed(self, qtbot, link_model, nodes):
    """Property changes propagate through both direct and silent links."""
    composite = nodes['cpm']
    read, read_2, read_3, read_4 = (nodes[name]
                                    for name in ('read', 'read_2', 'read_3', 'read_4'))
    # Read 2->height and cpm->Read->number silently depend on Read->number
    setup_silent(link_model, nodes)
    # Zero every linked property so that passing checks cannot be a coincidence
    read_3.model['number'] = 0
    read_4.model['number'] = 0
    composite.model['Pad']['width'] = 0
    link_model.add_item(read_3, read_3.model, 'number', 0, 1)
    link_model.add_item(read_4, read_4.model, 'number', 1, 0)
    link_model.add_item(composite, composite.model['Pad'], 'width', 1, 1)

    def change_number(model):
        model['number'] = 100
        model.property_changed.emit(model, 'number', model['number'])

    # Row 0: a direct link to Read 3 plus the two silent links
    change_number(read.model)
    assert read_3.model['number'] == read.model['number']
    assert read_2.model['height'] == read.model['number']
    assert composite.model['Read']['number'] == read.model['number']
    # Row 1: Read 4 drives cpm->Pad->width
    change_number(read_4.model)
    assert composite.model['Pad']['width'] == read_4.model['number']
tofu-0.12.0/tofu/tests/test_flow_propertylinkswidget.py 0000664 0000000 0000000 00000003225 14237137211 0023462 0 ustar 00root root 0000000 0000000 import pytest
from PyQt5.QtCore import Qt, QItemSelectionModel
from tofu.flow.propertylinkswidget import NodesView, PropertyLinks, PropertyLinksView
from tofu.tests.flow_util import get_index_from_treemodel, populate_link_model
@pytest.fixture(scope='function')
def node_view(node_model):
    """NodesView over *node_model*, set up so that items can be dragged out of it."""
    view = NodesView()
    settings = (
        (view.setHeaderHidden, True),
        (view.setAlternatingRowColors, True),
        (view.setDragEnabled, True),
        (view.setAcceptDrops, False),
    )
    for setter, flag in settings:
        setter(flag)
    view.setModel(node_model)
    return view
@pytest.fixture(scope='function')
def link_view():
    """A fresh PropertyLinksView without a model."""
    view = PropertyLinksView()
    return view
@pytest.fixture(scope='function')
def link_widget(node_model):
    """A fresh PropertyLinks widget."""
    # NOTE(review): node_model is requested but not passed on; presumably kept only as a
    # fixture dependency — confirm whether it can be dropped.
    widget = PropertyLinks()
    return widget
def test_property_links_view_delete_key(qtbot, link_model, link_view, nodes):
    """Pressing Delete in the view removes the selected column."""
    qtbot.addWidget(link_view)
    link_view.setModel(link_model)
    populate_link_model(link_model, nodes)
    first_column = 0
    link_view.selectColumn(first_column)
    qtbot.keyPress(link_view, Qt.Key_Delete)
    # After the deletion the model must report two columns
    assert link_model.columnCount() == 2
def test_node_view_get_drag_index(qtbot, node_view, nodes):
    """Only property rows, not node rows, provide a drag index."""
    model = node_view.model()
    read = nodes['read']
    model.add_node(read)
    selection = node_view.selectionModel()

    def select_and_get(index):
        # Select *index*, query the drag index and reset the selection again
        selection.select(index, QItemSelectionModel.Select)
        drag_index = node_view.get_drag_index()
        selection.clear()
        return drag_index

    # Nothing selected
    assert node_view.get_drag_index() is None
    # Node selection must yield nothing
    assert select_and_get(model.indexFromItem(model.item(0, 0))) is None
    # Property selection must yield an index which can be dragged
    assert select_and_get(get_index_from_treemodel(model, 0, 'number')) is not None
tofu-0.12.0/tofu/tests/test_flow_runslider.py 0000664 0000000 0000000 00000014710 14237137211 0021341 0 ustar 00root root 0000000 0000000 import pytest
from tofu.flow.runslider import RunSlider, RunSliderError
@pytest.fixture(scope='function')
def runslider(qtbot, scene):
    """RunSlider attached to the 'cutoff' property of a freshly created filter node."""
    slider = RunSlider()
    filter_node = scene.create_node(scene.registry.create('filter'))
    cutoff_item = filter_node.model._view._properties['cutoff'].view_item
    slider.setup(cutoff_item)
    qtbot.addWidget(slider)
    return slider
class TestRunSlider:
    """Tests for RunSlider."""

    def test_setup(self, qtbot, runslider):
        # Setting up again with the already attached view item returns a falsy value
        assert not runslider.setup(runslider.view_item)
        bottom = runslider.view_item.widget.validator().bottom()
        top = runslider.view_item.widget.validator().top()
        assert runslider.type is float
        assert runslider.real_minimum == bottom
        assert runslider.real_maximum == top
        assert float(runslider.min_edit.text()) == bottom
        assert float(runslider.max_edit.text()) == top
        assert float(runslider.current_edit.text()) == runslider.view_item.get()
        # The slider position must agree with the current value
        assert runslider.slider.value() / 100 + runslider.real_minimum == runslider.view_item.get()
        assert runslider.isEnabled()

    def test_reset(self, qtbot, runslider):
        runslider.reset()
        assert runslider.view_item is None
        assert runslider.type is None
        assert runslider.real_minimum == 0
        assert runslider.real_maximum == 100
        assert runslider.real_span == 100
        assert runslider.min_edit.text() == ''
        assert runslider.max_edit.text() == ''
        assert runslider.current_edit.text() == ''
        assert not runslider.isEnabled()

    def test_empty(self, qtbot, runslider):
        # Editing handlers must not fail on a slider without a view item
        runslider.reset()
        runslider.on_min_edit_editing_finished()
        runslider.on_max_edit_editing_finished()
        runslider.on_current_edit_editing_finished()

    def test_min_edit_changed(self, qtbot, runslider):
        top = runslider.view_item.widget.validator().top()
        # Not a number
        with pytest.raises(RunSliderError):
            runslider.min_edit.setText('asdf')
            runslider.on_min_edit_editing_finished()
        # Above the allowed maximum
        with pytest.raises(RunSliderError):
            runslider.min_edit.setText(str(top + 1))
            runslider.on_min_edit_editing_finished()
        # Current value lower than new minimum, must be updated
        value = runslider.get_real_value() + 0.1
        runslider.min_edit.setText(str(value))
        runslider.on_min_edit_editing_finished()
        assert value == runslider.get_real_value()

    def test_max_edit_changed(self, qtbot, runslider):
        bottom = runslider.view_item.widget.validator().bottom()
        # Not a number
        with pytest.raises(RunSliderError):
            runslider.max_edit.setText('asdf')
            runslider.on_max_edit_editing_finished()
        # Below the allowed minimum
        with pytest.raises(RunSliderError):
            runslider.max_edit.setText(str(bottom - 1))
            runslider.on_max_edit_editing_finished()
        # Current value greater than new maximum, must be updated
        value = runslider.get_real_value() - 0.1
        runslider.max_edit.setText(str(value))
        runslider.on_max_edit_editing_finished()
        assert value == runslider.get_real_value()

    def test_current_edit_changed(self, qtbot, runslider):
        self.value_changed_value = None

        def on_value_changed(value):
            self.value_changed_value = value

        # Connect before the "no change" check; otherwise a spurious emission could never
        # be observed and the assertion below would be vacuous.
        runslider.value_changed.connect(on_value_changed)
        # Nothing changed, no update triggered
        runslider.on_current_edit_editing_finished()
        assert self.value_changed_value is None
        current = runslider.get_real_value() + 0.1
        runslider.current_edit.setText(str(current))
        runslider.on_current_edit_editing_finished()
        assert runslider.get_real_value() == current
        runslider.current_edit.setText('asf')
        with pytest.raises(RunSliderError):
            runslider.on_current_edit_editing_finished()

    def test_int(self, qtbot, runslider, scene):
        node = scene.create_node(scene.registry.create('read'))
        runslider.setup(node.model._view._properties['y'].view_item)
        assert runslider.type is int
        runslider.min_edit.setText('1')
        runslider.on_min_edit_editing_finished()
        runslider.max_edit.setText('10')
        runslider.on_max_edit_editing_finished()
        runslider.slider.setValue(50)
        assert type(runslider.get_real_value()) is int
        assert runslider.get_real_value() == 5
        runslider.current_edit.setText('7')
        runslider.on_current_edit_editing_finished()
        assert type(runslider.get_real_value()) is int
        assert runslider.get_real_value() == 7
        # Maximum smaller than current -> update current
        runslider.max_edit.setText('5')
        runslider.on_max_edit_editing_finished()
        assert type(runslider.get_real_value()) is int
        assert runslider.get_real_value() == 5
        # Minimum greater than current -> update current
        runslider.max_edit.setText('10')
        runslider.on_max_edit_editing_finished()
        runslider.min_edit.setText('8')
        runslider.on_min_edit_editing_finished()
        assert type(runslider.get_real_value()) is int
        assert runslider.get_real_value() == 8

    def test_range(self, qtbot, scene):
        runslider = RunSlider()
        qtbot.addWidget(runslider)
        node = scene.create_node(scene.registry.create('general_backproject'))
        # A list with more than one element cannot be driven by the slider
        node.model['center-position-x'] = [1, 2, 3]
        assert not runslider.setup(node.model._view._properties['center-position-x'].view_item)
        assert runslider.view_item is None
        # A single-element list can
        node.model['center-position-x'] = [1]
        assert runslider.setup(node.model._view._properties['center-position-x'].view_item)
        assert runslider.view_item == node.model._view._properties['center-position-x'].view_item
        assert type(runslider.get_real_value()) is float
        assert runslider.get_real_value() == 1
        runslider.current_edit.setText('1.1')
        runslider.on_current_edit_editing_finished()
        assert node.model['center-position-x'] == [runslider.get_real_value()]

    def test_links(self, qtbot, link_model, nodes):
        runslider = RunSlider()
        qtbot.addWidget(runslider)
        read = nodes['read']
        read_2 = nodes['read_2']
        runslider.setup(read.model._view._properties['number'].view_item)
        link_model.add_item(read, read.model, 'number', -1, -1)
        link_model.add_item(read_2, read_2.model, 'number', 0, -1)
        # Changing the value through the slider must propagate via the property links
        runslider.current_edit.setText('123')
        runslider.on_current_edit_editing_finished()
        assert read_2.model['number'] == 123
tofu-0.12.0/tofu/tests/test_flow_scene.py 0000664 0000000 0000000 00000041406 14237137211 0020431 0 ustar 00root root 0000000 0000000 import pytest
from PyQt5.QtCore import QModelIndex
from PyQt5.QtWidgets import QInputDialog
from qtpynodeeditor import FlowView
from tofu.flow.models import BaseCompositeModel, UfoModelError, UfoReadModel
from tofu.flow.util import FlowError, MODEL_ROLE, NODE_ROLE, PROPERTY_ROLE
from tofu.tests.flow_util import add_nodes_to_scene
class TestScene:
    """Tests for the flow scene: nodes, composites, connections and (de)serialization."""

    def test_create_node(self, qtbot, scene):
        def check_node(node, gt_caption):
            # Node must be in the scene
            assert node in scene.nodes.values()
            # Caption must be unique
            assert node.model.caption == gt_caption
            # Node must be in the nodes model
            item = scene.node_model.findItems(node.model.caption)[0]
            assert item.data(role=MODEL_ROLE) == node.model

        nodes = add_nodes_to_scene(scene, model_names=['read', 'read'])
        for (node, gt_caption) in zip(nodes, ['Read', 'Read 2']):
            check_node(node, gt_caption)

        # Property links must be set up by composites
        scene.clear_scene()
        node = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect'])[0]
        for link in node.model._links:
            model_name, prop_name = link[0]
            other_name, other_prop_name = link[1]
            model = node.model[model_name]
            other = node.model[other_name]
            model[prop_name] = 0
            qtbot.addWidget(model.embedded_widget())
            qtbot.addWidget(other.embedded_widget())
            qtbot.keyClick(model._view._properties[prop_name].view_item.widget, '1')
            # Other model's property has to be updated if the property links have been set up
            # correctly
            assert node.model[other_name][other_prop_name] == node.model[model_name][prop_name]

    def test_setstate(self, qtbot, scene):
        # Make sure there are some links by adding FFC
        nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average'])
        # Create a connection
        scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0])
        state = scene.__getstate__()
        scene.clear_scene()
        scene.__setstate__(state)
        # Restoring and serializing again must reproduce the same state
        assert scene.__getstate__() == state

    def test_getstate(self, qtbot, scene):
        # Make sure there are some links by adding FFC
        nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average'])
        # Create a connection
        scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0])
        state = scene.__getstate__()
        # Nodes
        ids = [record['id'] for record in state['nodes']]
        assert nodes[0].id in ids
        assert nodes[1].id in ids
        # Connections
        assert len(state['connections']) == 1
        conn = state['connections'][0]
        assert conn['in_id'] == nodes[1].id
        assert conn['out_id'] == nodes[0].id
        # Property links
        assert state['property-links'] == scene.property_links_model.save()

    def test_restore_node(self, qtbot, monkeypatch, scene):
        add_nodes_to_scene(scene)
        old_node = list(scene.nodes.values())[0]
        state = old_node.__getstate__()
        scene.remove_node(old_node)
        new_node = scene.restore_node(state)
        # Don't test the nodes themselves because the models won't match
        assert old_node.id == new_node.id
        # num-inputs
        monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True))
        node = add_nodes_to_scene(scene, model_names=['retrieve_phase'])[0]
        state = node.__getstate__()
        scene.remove_node(node)
        new_node = scene.restore_node(state)
        assert new_node.model.num_ports['input'] == 2

    def test_remove_node(self, monkeypatch, qtbot, scene, nodes):
        # NOTE(review): the `nodes` fixture is requested but not used here; it is kept so that
        # the fixture setup of this test stays unchanged.
        def cleanup():
            self.cleanup_called = True

        node = add_nodes_to_scene(scene)[0]
        self.cleanup_called = False
        node.model.cleanup = cleanup
        scene.property_links_model.add_item(node, node.model, node.model.properties[0],
                                            0, 0, QModelIndex())
        scene.remove_node(list(scene.nodes.values())[0])
        # Scene, node model and property links model must be empty
        assert len(scene.nodes) == 0
        assert scene.node_model.rowCount() == 0
        assert scene.property_links_model.rowCount() == 0
        assert self.cleanup_called

        # Composite removal
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
        parts = add_nodes_to_scene(scene, model_names=['pad', 'crop'])
        for part in parts:
            part.graphics_object.setSelected(True)
        node = scene.create_composite()
        state = node.__getstate__()
        scene.remove_node(node)
        # _composite_nodes must be updated
        assert scene._composite_nodes == {}
        # Simulate non-interactive composite creation, i.e. not combining existing nodes into a
        # composite node. When removing such node, _composite_nodes must not raise a KeyError
        node = scene.restore_node(state)
        scene.remove_node(node)

    def test_is_selected_one_composite(self, qtbot, scene, monkeypatch):
        # Circumvent the input dialog
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
        nodes = add_nodes_to_scene(scene, model_names=['read', 'read'])
        for node in nodes:
            node.graphics_object.setSelected(True)
        # Simple nodes
        assert not scene.is_selected_one_composite()
        node = scene.create_composite()
        # Composite
        assert scene.is_selected_one_composite()
        node.graphics_object.setSelected(False)
        # Nothing selected
        assert not scene.is_selected_one_composite()
        # Composite and other selected
        add_nodes_to_scene(scene, ['null'])
        for node in scene.nodes.values():
            node.graphics_object.setSelected(True)
        assert not scene.is_selected_one_composite()

    def test_skip_nodes(self, qtbot, scene):
        nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null'])
        read, pad, crop, null = nodes
        scene.create_connection(read['output'][0], pad['input'][0])
        scene.create_connection(pad['output'][0], crop['input'][0])
        scene.create_connection(crop['output'][0], null['input'][0])
        read.graphics_object.setSelected(True)
        # Only fully connected nodes can be disabled
        with pytest.raises(FlowError):
            scene.skip_nodes()
        read.graphics_object.setSelected(False)
        null.graphics_object.setSelected(True)
        with pytest.raises(FlowError):
            scene.skip_nodes()
        null.graphics_object.setSelected(False)
        pad.graphics_object.setSelected(True)
        # Skipping toggles
        scene.skip_nodes()
        assert pad.model.skip
        scene.skip_nodes()
        assert not pad.model.skip

    # Deprecation warning coming from imageio
    @pytest.mark.filterwarnings('ignore::DeprecationWarning')
    def test_auto_fill(self, qtbot, scene):
        add_nodes_to_scene(scene)
        with pytest.raises(UfoModelError):
            scene.auto_fill()

    def test_copy_node(self, qtbot, scene):
        nodes = add_nodes_to_scene(scene, model_names=['read', 'null'])
        scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0])
        for node in nodes:
            node.graphics_object.setSelected(True)
        scene.copy_nodes()
        assert len(scene.nodes) == 4
        # Choose the newly created connections
        if scene.connections[0].valid_ports['input'].node in nodes:
            ports = scene.connections[1].valid_ports
        else:
            ports = scene.connections[0].valid_ports
        # The fact that the connections are there means the nodes are there as well, so we don't
        # need to test that
        assert ports['input'].node.model.name == 'null'
        assert ports['output'].node.model.name == 'read'

    def test_create_composite(self, monkeypatch, qtbot, scene):
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
        plm = scene.property_links_model
        nodes = add_nodes_to_scene(scene, model_names=['read', 'read'])
        plm.add_item(nodes[0], nodes[0].model, nodes[0].model.properties[0],
                     -1, -1, QModelIndex())
        for node in nodes:
            node.graphics_object.setSelected(True)
        node = scene.create_composite()
        assert node.model._links == [[[nodes[0].model.caption, nodes[0].model.properties[0]]]]
        with pytest.raises(FlowError):
            # Can't create a composite with the same name
            scene.create_composite()
        assert len(scene.nodes) == 1
        assert list(scene.nodes.values())[0] == node
        assert isinstance(node.model, BaseCompositeModel)
        assert nodes[0] not in scene.nodes.values()
        assert nodes[1] not in scene.nodes.values()
        # Property links model
        assert plm.item(0, 0).data(role=NODE_ROLE) == node
        # Simulate non-interactive composite creation, i.e. not combining existing nodes into a
        # composite node. In this case it can't be possible to create a new composite node with the
        # same name as has already been registered.
        state = node.__getstate__()
        scene.remove_node(node)
        node = scene.restore_node(state)
        node.graphics_object.setSelected(True)
        with pytest.raises(FlowError):
            scene.create_composite()
        # Add outer composite with a composite and another simple model inside and set the property
        # links between the inner composite and inner simple. They must be present in the newly
        # created outer composite.
        average = add_nodes_to_scene(scene, model_names=['average'])[0]
        average.graphics_object.setSelected(True)
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('outer', True))
        plm.add_item(node, node.model['Read'], 'number', 0, 0)
        plm.add_item(node, node.model['Read 2'], 'number', 0, 1)
        plm.add_item(average, average.model, 'number', 0, 2)
        outer = scene.create_composite()
        assert len(outer.model._links) == 1
        row = outer.model._links[0]
        assert [node.model.caption, node.model['Read'].caption, 'number'] in row
        assert [node.model.caption, node.model['Read 2'].caption, 'number'] in row
        assert [average.model.caption, 'number'] in row
        assert [node.model.caption, node.model['Read'].caption, 'height'] not in row

    def test_on_node_double_clicked(self, qtbot, scene, monkeypatch):
        def double_clicked(*args):
            self.did_click = True

        self.did_click = False
        monkeypatch.setattr(UfoReadModel, "double_clicked", double_clicked)
        node = add_nodes_to_scene(scene)[0]
        # We need a view for double clicks; keep a reference so it stays alive
        _ = FlowView(scene)
        scene.on_node_double_clicked(node)
        assert self.did_click

    def test_expand_composite(self, qtbot, scene, monkeypatch):
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
        plm = scene.property_links_model
        nodes = add_nodes_to_scene(scene, model_names=['read', 'null'])
        name_to_caption = {'read': 'Read', 'null': 'Null'}
        for node in nodes:
            node.graphics_object.setSelected(True)
        node = scene.create_composite()
        path = node.model.get_leaf_paths()[0]
        plm.add_item(node, path[-1], path[-1].properties[0], -1, -1, QModelIndex())
        scene.expand_composite(node)
        assert plm.item(0, 0).data(role=MODEL_ROLE).name == path[-1].name
        assert plm.item(0, 0).data(role=NODE_ROLE) in list(scene.selected_nodes())
        # Captions are the same
        for node in scene.nodes.values():
            assert node.model.caption == name_to_caption[node.model.name]
        # New caption if there is a node with the original one
        # Selection stays, just re-use the expanded nodes
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True))
        node = scene.create_composite()
        other_read_node = add_nodes_to_scene(scene, model_names=['read'])[0]
        scene.expand_composite(node)
        for node in scene.nodes.values():
            if node.model.name == 'read':
                if node == other_read_node:
                    assert node.model.caption == 'Read'
                else:
                    assert node.model.caption == 'Read 2'

    def test_is_fully_connected(self, qtbot, scene):
        nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null'])
        read, pad, crop, null = nodes
        scene.create_connection(read['output'][0], pad['input'][0])
        scene.create_connection(pad['output'][0], crop['input'][0])
        scene.create_connection(crop['output'][0], null['input'][0])
        assert scene.is_fully_connected()
        scene.remove_node(read)
        assert not scene.is_fully_connected()

    def test_get_simple_node_graphs(self, qtbot, scene, monkeypatch):
        def connect(read, pad, crop, null):
            scene.create_connection(read['output'][0], pad['input'][0])
            scene.create_connection(pad['output'][0], crop['input'][0])
            scene.create_connection(crop['output'][0], null['input'][0])

        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm', True))
        nodes = add_nodes_to_scene(scene, model_names=2 * ['read', 'pad', 'crop', 'null'])
        read, pad, crop, null = nodes[:4]
        read_2, pad_2, crop_2, null_2 = nodes[4:]
        connect(read, pad, crop, null)
        connect(read_2, pad_2, crop_2, null_2)
        connections = [('Read', 'Pad'), ('Pad', 'Crop'), ('Crop', 'Null'),
                       ('Read 2', 'Pad 2'), ('Pad 2', 'Crop 2'), ('Crop 2', 'Null 2')]
        graphs = scene.get_simple_node_graphs()
        assert len(graphs) == 2
        num_visited = 0
        for graph in graphs:
            for (src, dst, _) in graph.edges:
                assert (src.caption, dst.caption) in connections
                num_visited += 1
        assert num_visited == len(connections)
        # Create first composite
        for node in nodes:
            if node.model.name in ['pad', 'crop']:
                node.graphics_object.setSelected(True)
        scene.create_composite()
        # Create a second composite which will cause the scene to have multiple edges between two
        # nodes (the first composite's outputs and second's inputs)
        scene.clearSelection()
        monkeypatch.setattr(QInputDialog, "getText", lambda *args: ('cpm_2', True))
        null.graphics_object.setSelected(True)
        null_2.graphics_object.setSelected(True)
        scene.create_composite()
        # Composite must not affect simple graphs, especially the multiple edges cannot be present
        # anymore
        graphs = scene.get_simple_node_graphs()
        assert len(graphs) == 2
        num_visited = 0
        for graph in graphs:
            for (src, dst, _) in graph.edges:
                assert (src.caption, dst.caption) in connections
                num_visited += 1
        assert num_visited == len(connections)
        add_nodes_to_scene(scene)
        assert len(scene.get_simple_node_graphs()) == 3
        # Test disabling nodes
        scene.clear_scene()
        nodes = add_nodes_to_scene(scene, model_names=['read', 'pad', 'crop', 'null'])
        read, pad, crop, null = nodes
        connect(read, pad, crop, null)
        # Disable padding, the generated flow must be read -> crop -> null
        pad.graphics_object.setSelected(True)
        scene.skip_nodes()
        graph = scene.get_simple_node_graphs()[0]
        assert len(graph.edges) == 2
        edges = list(graph.edges)
        src, dst = edges[0][:-1]
        # The skipped node must be bridged: read feeds crop directly
        assert src == read.model
        assert dst == crop.model
        src, dst = edges[1][:-1]
        assert src == crop.model
        assert dst == null.model

    def test_set_enabled(self, qtbot, scene):
        def check(enabled):
            assert scene.allow_node_creation == enabled
            assert scene.allow_node_deletion == enabled
            for node in scene.nodes.values():
                assert node._graphics_obj.isEnabled() == enabled
            for conn in scene.connections:
                assert conn._graphics_object.isEnabled() == enabled

        nodes = add_nodes_to_scene(scene, model_names=['CFlatFieldCorrect', 'average'])
        nodes[0].graphics_object.setSelected(True)
        # Create a connection
        scene.create_connection(nodes[0]['output'][0], nodes[1]['input'][0])
        scene.set_enabled(False)
        check(False)
        scene.set_enabled(True)
        check(True)
        # Toggling must not lose the selection
        assert nodes[0].graphics_object.isSelected()
tofu-0.12.0/tofu/tests/test_flow_util.py 0000664 0000000 0000000 00000003667 14237137211 0020320 0 ustar 00root root 0000000 0000000 import pytest
from PyQt5.QtWidgets import QInputDialog
from tofu.flow.util import CompositeConnection, get_config_key, saved_kwargs
from tofu.flow.main import get_filled_registry
def test_get_config_key():
    """get_config_key resolves a key path and falls back to a default when it is missing."""
    hidden = get_config_key('models', 'general-backproject', 'hidden-properties')
    assert 'z' in hidden
    # A missing key gives None unless a default is supplied
    missing = 'foobarbaz'
    assert get_config_key(missing) is None
    assert get_config_key(missing, default=1) == 1
def test_saved_kwargs(qtbot, monkeypatch, scene):
    """saved_kwargs injects a saved num-inputs into model creation, only within its context."""
    registry = get_filled_registry()
    name = 'retrieve_phase'
    # No num-inputs info: the number of inputs comes from the (patched) input dialog
    monkeypatch.setattr(QInputDialog, 'getInt', lambda *args, **kwargs: (2, True))
    model = registry.create(name)
    assert model.num_ports['input'] == 2
    # num-inputs specified: the saved value is used instead of the dialog's
    state = {'name': name, 'num-inputs': 3}
    with saved_kwargs(registry, state):
        model = registry.create(name)
        assert model.num_ports['input'] == 3
    # The injected keyword argument must not stay registered after the context exits
    assert 'num_inputs' not in registry.registered_model_creators()[state['name']][1]
class TestCompositeConnection:
    """Tests for CompositeConnection."""

    def test_init(self):
        # Identical source and target -> exception
        with pytest.raises(ValueError):
            CompositeConnection('a', 0, 'a', 0)
        # Distinct endpoints must be accepted
        CompositeConnection('a', 0, 'b', 0)

    def test_contains(self):
        conn = CompositeConnection('a', 0, 'b', 0)
        # (node, port type, port index) -> whether the connection touches it
        expectations = [
            (('a', 'output', 0), True),
            (('a', 'output', 1), False),
            (('a', 'input', 0), False),
            (('a', 'input', 1), False),
            (('b', 'input', 0), True),
            (('b', 'input', 1), False),
            (('b', 'output', 0), False),
            (('b', 'output', 1), False),
            (('foo', 'input', 14), False),
        ]
        for args, expected in expectations:
            assert bool(conn.contains(*args)) is expected

    def test_save(self):
        saved = CompositeConnection('a', 0, 'b', 0).save()
        assert saved == ['a', 0, 'b', 0]
tofu-0.12.0/tofu/tests/test_flow_viewer.py 0000664 0000000 0000000 00000027234 14237137211 0020640 0 ustar 00root root 0000000 0000000 import pytest
import numpy as np
from PyQt5.QtGui import QValidator
from tofu.flow.viewer import ImageLabel, ImageViewingError, ScreenImage, ImageViewer
@pytest.fixture(scope='function')
def screen_image():
    """ScreenImage over a 16x16 float32 ramp holding the values 0..255."""
    ramp = np.arange(16 * 16, dtype=np.float32).reshape(16, 16)
    return ScreenImage(image=ramp)
@pytest.fixture(scope='function')
def viewer(qtbot):
    """ImageViewer loaded with a stack of ten 16x16 images of ones and popped up."""
    image_viewer = ImageViewer()
    image_viewer.images = np.ones((10, 16, 16))
    image_viewer.popup()
    # Register the pyqtgraph window so qtbot closes it after the test
    qtbot.addWidget(image_viewer._pg_window)
    return image_viewer
class TestScreenImage:
def test_image_setter(self):
    """Assigning an image initializes the extrema and the black/white points."""
    screen_image = ScreenImage()
    assert screen_image.image is None
    screen_image.image = np.random.normal(size=(8, 8))
    for attribute in ('minimum', 'maximum', 'black_point', 'white_point'):
        assert getattr(screen_image, attribute) is not None
def test_black_point_setter(self, screen_image):
    """The black point can be set but must not exceed the white point."""
    screen_image.black_point = 100
    assert screen_image.black_point == 100
    screen_image.white_point = 150
    with pytest.raises(ImageViewingError):
        # 200 lies above the white point of 150
        screen_image.black_point = 200
def test_white_point_setter(self, screen_image):
screen_image.white_point = 100
assert screen_image.white_point == 100
screen_image.black_point = 50
# White point cannot be smaller than black point
with pytest.raises(ImageViewingError):
screen_image.white_point = 0
def test_reset(self, screen_image):
screen_image.reset()
# We can test with ==, the data types are the same so the extrema must be exactly the same
assert screen_image.minimum == 0
assert screen_image.maximum == 255
assert screen_image.black_point == 0
assert screen_image.white_point == 255
# Going out of the original gray value range must not cause exception on reset
screen_image.black_point = -100
screen_image.white_point = -50
screen_image.reset()
def test_auto_levels(self, screen_image):
screen_image.auto_levels()
# nonsense values must pass as well
screen_image.auto_levels(percentile=200.0)
screen_image.auto_levels(percentile=-200.0)
def test_set_black_point_normalized(self, screen_image):
screen_image.set_black_point_normalized(100)
assert screen_image.black_point == 100
screen_image.set_white_point_normalized(150)
# Black point cannot be greater than white point
with pytest.raises(ImageViewingError):
screen_image.set_black_point_normalized(200)
def test_set_white_point_normalized(self, screen_image):
screen_image.set_white_point_normalized(100)
assert screen_image.white_point == 100
screen_image.set_black_point_normalized(50)
# White point cannot be smaller than black point
with pytest.raises(ImageViewingError):
screen_image.set_white_point_normalized(0)
def test_convert_normalized_value_to_native(self, screen_image):
assert screen_image.convert_normalized_value_to_native(128) == 128.
with pytest.raises(ImageViewingError):
screen_image.convert_normalized_value_to_native(-500)
with pytest.raises(ImageViewingError):
screen_image.convert_normalized_value_to_native(500)
def test_convert_native_value_to_normalized(self, screen_image):
assert screen_image.convert_native_value_to_normalized(128) == 128.
with pytest.raises(ImageViewingError):
screen_image.convert_native_value_to_normalized(-500)
with pytest.raises(ImageViewingError):
screen_image.convert_native_value_to_normalized(500)
# One gray value must not cause division by zero erro
screen_image.image = np.ones((4, 4))
screen_image.reset()
screen_image.convert_native_value_to_normalized(1)
def test_get_pixmap(self, qtbot, screen_image):
# Empty image must raise an exception
with pytest.raises(ImageViewingError):
ScreenImage().get_pixmap()
# Downsampling
pixmap = screen_image.get_pixmap()
assert (pixmap.height(), pixmap.width()) == screen_image.image.shape
pixmap = screen_image.get_pixmap(downsampling=2)
assert (pixmap.height(), pixmap.width()) == tuple(dim // 2 for dim in
screen_image.image.shape)
# One gray value must not cause division by zero erro
screen_image.image = np.ones((4, 4))
screen_image.reset()
screen_image.get_pixmap()
class TestImageLabel:
    """Tests for ImageLabel displaying a ScreenImage."""

    def test_updateImage(self, qtbot, screen_image):
        label = ImageLabel()
        # Empty image must pass
        label.updateImage()
        label.screen_image = screen_image
        label.updateImage()
        assert label.pixmap() is not None

    def test_resizeEvent(self, qtbot, screen_image):
        label = ImageLabel(screen_image)
        label.updateImage()
        old_size = label.pixmap().size()
        # ensure the label will get the resize event
        label.show()
        label.resize(8, 8)
        # The pixmap must have been re-rendered for the new label size
        new_size = label.pixmap().size()
        assert new_size != old_size
class TestImageViewer:
    """Tests for ImageViewer: image storage, slider/edit widgets and the pop-up window."""

    def test_images_setter(self, qtbot):
        viewer = ImageViewer()
        # A single 2D image is stored as a 3D stack and the index slider is hidden
        viewer.images = np.zeros((16, 16))
        assert viewer.images.ndim == 3
        assert viewer.slider.isHidden()
        assert float(viewer.min_slider_edit.text()) == 0
        assert float(viewer.max_slider_edit.text()) == 0
        # A real stack shows the slider spanning all image indices
        viewer.images = np.ones((5, 16, 16))
        assert viewer.images.ndim == 3
        assert not viewer.slider.isHidden()
        assert viewer.slider.minimum() == 0
        assert viewer.slider.maximum() == viewer.images.shape[0] - 1
        assert float(viewer.min_slider_edit.text()) == 1
        assert float(viewer.max_slider_edit.text()) == 1
        # Test viewer and popup window equality
        viewer.popup()
        qtbot.addWidget(viewer._pg_window)
        np.testing.assert_almost_equal(viewer.images, 1)
        np.testing.assert_almost_equal(viewer._pg_window.image, 1)
        # 3D
        viewer.images = np.ones((5, 16, 16)) * 5
        np.testing.assert_almost_equal(viewer.images, 5)
        np.testing.assert_almost_equal(viewer._pg_window.image, 5)
        # 2D
        viewer.images = np.ones((16, 16)) * 3
        np.testing.assert_almost_equal(viewer.images, 3)
        np.testing.assert_almost_equal(viewer._pg_window.image, 3)
        # validators
        viewer.images = 10 + np.arange(200 * 8 ** 2).reshape(200, 8, 8)
        validator = viewer.slider_edit.validator()
        # Image index edit accepts only indices within the stack of 200 images
        assert validator.validate('199', 0)[0] == QValidator.Acceptable
        assert validator.validate('2000', 0)[0] == QValidator.Invalid
        # Gray value edit is bounded by the first image's gray value range
        assert viewer.min_slider_edit.validator().bottom() == viewer.images[0].min()
        assert viewer.min_slider_edit.validator().top() == viewer.images[0].max()
        viewer._pg_window.close()

    def test_append(self, qtbot):
        viewer = ImageViewer()
        # Append to empty
        viewer.append(np.zeros((4, 4)))
        assert viewer.images.ndim == 3
        assert viewer.images.shape == (1, 4, 4)
        # Append 2D
        viewer.append(np.zeros((4, 4)))
        assert viewer.images.shape == (2, 4, 4)
        # Append 3D
        viewer.append(np.zeros((3, 4, 4)))
        assert viewer.images.shape == (5, 4, 4)
        # Append wrong shape
        with pytest.raises(ImageViewingError):
            viewer.append(np.zeros((3, 2, 2)))

    def test_set_enabled_adjustments(self, qtbot):
        viewer = ImageViewer()

        def assert_all(value):
            # Every slider and its line edit must follow the requested enabled state
            viewer.set_enabled_adjustments(value)
            assert viewer.slider.isEnabled() == value
            assert viewer.slider_edit.isEnabled() == value
            assert viewer.min_slider.isEnabled() == value
            assert viewer.min_slider_edit.isEnabled() == value
            assert viewer.max_slider.isEnabled() == value
            assert viewer.max_slider_edit.isEnabled() == value

        assert_all(True)
        assert_all(False)

    def test_reset_clim(self, viewer):
        image = np.arange(16 ** 2).reshape(16, 16)
        viewer.images = image
        viewer.append(image * 2)
        # Select the second image whose maximum is 2 * 255 = 510
        viewer.slider.setValue(1)
        viewer.reset_clim(auto=False)
        si = viewer.screen_image
        min_converted = si.convert_native_value_to_normalized(si.black_point)
        max_converted = si.convert_native_value_to_normalized(si.white_point)
        assert viewer.screen_image.maximum == pytest.approx(510)
        # Sliders show normalized values, edits show native gray values
        assert viewer.min_slider.value() == pytest.approx(min_converted)
        assert viewer.max_slider.value() == pytest.approx(max_converted)
        assert float(viewer.min_slider_edit.text()) == pytest.approx(si.black_point)
        assert float(viewer.max_slider_edit.text()) == pytest.approx(si.white_point)
        # Pop up window must be updated
        assert viewer._pg_window.getLevels() == pytest.approx((si.black_point, si.white_point))
        viewer._pg_window.close()

    def test_on_slider_value_changed(self, viewer):
        # Moving the index slider updates the pop-up's current image and the edit
        viewer.slider.setValue(5)
        assert viewer._pg_window.currentIndex == 5
        assert viewer.slider_edit.text() == '5'
        viewer._pg_window.close()

    def test_on_slider_edit_return_pressed(self, viewer):
        # Confirming the index edit moves the slider and the pop-up's current image
        viewer.slider_edit.setText('5')
        viewer.slider_edit.returnPressed.emit()
        assert viewer.slider.value() == 5
        assert viewer._pg_window.currentIndex == 5
        viewer._pg_window.close()

    def test_on_min_slider_edit_return_pressed(self, viewer):
        # For a 0..255 image native and slider values coincide
        viewer.images = np.arange(256).reshape(16, 16)
        viewer.min_slider_edit.setText('100')
        viewer.min_slider_edit.returnPressed.emit()
        assert viewer.screen_image.black_point == pytest.approx(100)
        assert viewer.min_slider.value() == pytest.approx(100)
        assert viewer._pg_window.getLevels()[0] == pytest.approx(100)
        viewer._pg_window.close()

    def test_on_max_slider_edit_return_pressed(self, viewer):
        viewer.images = np.arange(256).reshape(16, 16)
        viewer.max_slider_edit.setText('100')
        viewer.max_slider_edit.returnPressed.emit()
        assert viewer.screen_image.white_point == pytest.approx(100)
        assert viewer.max_slider.value() == pytest.approx(100)
        assert viewer._pg_window.getLevels()[1] == pytest.approx(100)
        viewer._pg_window.close()

    def test_on_min_slider_value_changed(self, viewer):
        viewer.images = np.arange(256).reshape(16, 16)
        viewer.min_slider.valueChanged.emit(100)
        assert viewer.screen_image.black_point == pytest.approx(100)
        assert float(viewer.min_slider_edit.text()) == pytest.approx(100)
        assert viewer._pg_window.getLevels()[0] == pytest.approx(100)
        viewer._pg_window.close()

    def test_on_max_slider_value_changed(self, viewer):
        viewer.images = np.arange(256).reshape(16, 16)
        viewer.max_slider.valueChanged.emit(100)
        assert viewer.screen_image.white_point == pytest.approx(100)
        assert float(viewer.max_slider_edit.text()) == pytest.approx(100)
        assert viewer._pg_window.getLevels()[1] == pytest.approx(100)
        viewer._pg_window.close()

    def test_popup(self, qtbot, viewer):
        # Close and another popup call must show the widget
        viewer._pg_window.close()
        viewer.popup()
        assert viewer._pg_window.isVisible()
        # 2D must work
        other = ImageViewer()
        other.images = np.ones((4, 4))
        other.popup()
        qtbot.addWidget(other._pg_window)
        assert other._pg_window is not None
        viewer._pg_window.close()
        other._pg_window.close()
tofu-0.12.0/tofu/util.py 0000664 0000000 0000000 00000024606 14237137211 0015064 0 ustar 00root root 0000000 0000000 """Various utility functions."""
import argparse
import glob
import logging
import math
import os
from collections import OrderedDict
LOG = logging.getLogger(__name__)
def range_list(value):
    """
    Parse a ``from[:to[:step]]`` string into an integer triple ``(from, to, step)``.

    A single number ``n`` yields ``(n, n + 1, 1)``; with two numbers the step is 1.
    When *to* is given it must be greater than *from*. Invalid ranges and more than
    three fields raise :class:`argparse.ArgumentTypeError`.
    """
    numbers = [int(token) for token in value.split(':')]
    count = len(numbers)

    if count == 1:
        start = numbers[0]
        return (start, start + 1, 1)

    if count not in (2, 3):
        raise argparse.ArgumentTypeError("Cannot parse {}".format(value))

    start, stop = numbers[0], numbers[1]
    if start >= stop:
        raise argparse.ArgumentTypeError("{} must be less than {}".format(start, stop))

    step = numbers[2] if count == 3 else 1
    return (start, stop, step)
def make_subargs(args, subargs):
    """Build a new argparse.Namespace holding only the attributes of *args* named in *subargs*."""
    return argparse.Namespace(**{name: getattr(args, name) for name in subargs})
def set_node_props(node, args):
    """Forward matching attributes of *args* to *node*'s properties.

    *args* is an attribute container such as an :class:`argparse.Namespace`. Every
    public name in ``dir(node.props)`` that *args* also has is passed to
    ``node.set_property`` unless its value is None.
    """
    public_names = (name for name in dir(node.props) if not name.startswith('_'))
    for name in public_names:
        if not hasattr(args, name):
            continue
        value = getattr(args, name)
        if value is None:
            continue
        LOG.debug("Setting {}:{} to {}".format(node.get_plugin_name(), name, value))
        node.set_property(name, getattr(args, name))
def get_filenames(path):
    """
    Return the sorted list of files matched by *path*.

    *path* is either a directory (all its entries are listed) or a glob pattern.
    """
    pattern = path
    if os.path.isdir(path):
        pattern = os.path.join(path, '*')
    return sorted(glob.glob(pattern))
def setup_read_task(task, path, args):
    """Point the read *task* at *path* and configure raw-file parameters when needed.

    If the first file matched by *path* ends in ``.raw``, the raw width, height and
    bit depth are copied from *args*. Width and height must be set, otherwise
    RuntimeError is raised.
    """
    task.props.path = path
    filenames = get_filenames(path)
    is_raw = bool(filenames) and filenames[0].endswith('.raw')
    if not is_raw:
        return

    if not (args.width and args.height):
        raise RuntimeError("Raw files require --width, --height and --bitdepth arguments.")
    task.props.raw_width = args.width
    task.props.raw_height = args.height
    task.props.raw_bitdepth = args.bitdepth
def restrict_value(limits, dtype=float):
    """Return an argparse ``type`` callable that converts with *dtype* and enforces *limits*.

    *limits* is a ``(min, max)`` pair of inclusive bounds; a None bound is not checked.
    Out-of-range values raise :class:`argparse.ArgumentTypeError`.
    """
    def validate(text):
        number = dtype(text)
        lower = limits[0]
        if lower is not None and number < lower:
            raise argparse.ArgumentTypeError('Value cannot be less than {}'.format(lower))
        upper = limits[1]
        if upper is not None and number > upper:
            raise argparse.ArgumentTypeError('Value cannot be greater than {}'.format(upper))
        return number

    return validate
def convert_filesize(value):
    """Convert a size string such as ``'512'``, ``'1.5m'`` or ``'2G'`` to a number of bytes.

    A trailing ``k``, ``m``, ``g`` or ``t`` (case-insensitive) multiplies the number by
    2**10, 2**20, 2**30 or 2**40 respectively. The result is truncated to an int.

    :raises argparse.ArgumentTypeError: for an empty string, an unknown suffix or a
        negative size. A non-numeric body raises ValueError from ``float``, which
        argparse reports as an invalid value.
    """
    conv = OrderedDict((('k', 2 ** 10),
                        ('m', 2 ** 20),
                        ('g', 2 ** 30),
                        ('t', 2 ** 40)))
    if not value:
        # value[-1] below would raise IndexError, which argparse does not translate
        raise argparse.ArgumentTypeError('--output-bytes-per-file must not be empty')

    multiplier = 1
    suffix = value[-1].lower()
    if not suffix.isdigit():
        if suffix not in conv:
            raise argparse.ArgumentTypeError('--output-bytes-per-file must either be a ' +
                                             'number or end with {} '.format(list(conv.keys())) +
                                             'to indicate kilo, mega, giga or terabytes')
        multiplier = conv[suffix]
        value = value[:-1]

    value = int(float(value) * multiplier)
    if value < 0:
        raise argparse.ArgumentTypeError('--output-bytes-per-file cannot be less than zero')

    return value
def tupleize(num_items=None, conv=float, dtype=tuple):
    """Return an argparse ``type`` callable converting comma-separated strings to *dtype*.

    Every item is converted with *conv* and the items are collected into *dtype*
    (a tuple by default). If *num_items* is truthy, exactly that many items are required.
    Failures raise :class:`argparse.ArgumentTypeError`.
    """
    def split_values(value):
        """Convert comma-separated string *value* to a *dtype* of converted items."""
        try:
            result = dtype([conv(x) for x in value.split(',')])
        except Exception as err:
            # Narrower than a bare except (lets KeyboardInterrupt/SystemExit through) and
            # keeps the original error as the cause for debugging
            raise argparse.ArgumentTypeError('Expect comma-separated tuple') from err
        if num_items and len(result) != num_items:
            raise argparse.ArgumentTypeError('Expected {} items'.format(num_items))
        return result

    return split_values
def next_power_of_two(number):
    """Return the smallest power of two that is greater than or equal to *number*.

    For ``number >= 1`` exact integer arithmetic is used. ``math.log(x, 2)`` is not exact:
    ``math.log(2 ** 29, 2)`` is 29.000000000000004, which would double exact powers of two.
    Numbers below 1 keep the original logarithm-based result (e.g. 0.5 -> 0.5). Numbers
    <= 0 raise ValueError from ``math.log``.
    """
    if number >= 1:
        return 1 << (int(math.ceil(number)) - 1).bit_length()
    return 2 ** int(math.ceil(math.log(number, 2)))
def read_image(filename):
    """Read image from file *filename*.

    TIFF files (``.tif``/``.tiff``, case-insensitive) are read with tifffile using its
    ``out='memmap'`` output. Names containing ``.edf`` are read with fabio.

    :raises ValueError: for any other file name.
    """
    lower_name = filename.lower()
    if lower_name.endswith(('.tif', '.tiff')):
        from tifffile import TiffFile
        with TiffFile(filename) as tif:
            return tif.asarray(out='memmap')
    # Substring match on purpose: also accepts names like ``x.edf.gz``
    if '.edf' in lower_name:
        import fabio
        edf = fabio.edfimage.edfimage()
        edf.read(filename)
        return edf.data
    raise ValueError('Unsupported image format')
def get_image_shape(filename):
    """Return the image shape of *filename* in NumPy order.

    For TIFF files only the first page header is inspected. A multi-page file
    yields ``(num_pages, height, width)``. Other formats are read completely.
    """
    if not (filename.lower().endswith('.tif') or filename.lower().endswith('.tiff')):
        # fabio doesn't seem to be able to read the shape without reading the data
        return read_image(filename).shape

    from tifffile import TiffFile
    with TiffFile(filename) as tif:
        first_page = tif.pages[0]
        shape = (first_page.imagelength, first_page.imagewidth)
        page_count = len(tif.pages)

    if page_count > 1:
        return (page_count,) + shape
    return shape
def get_first_filename(path):
    """Return the first (sorted) file matched by *path*.

    :raises RuntimeError: if *path* is empty/None or nothing matches it.
    """
    if not path:
        raise RuntimeError("Path to sinograms or projections not set.")
    matches = get_filenames(path)
    if matches:
        return matches[0]
    raise RuntimeError("No files found in `{}'".format(path))
def determine_shape(args, path=None, store=False):
    """Determine input shape from *args* which means either width and height are specified in args
    or try to read the *path* and determine the shape from it. The default path is args.projections,
    which is the typical place to find the input. If *store* is True, assign the determined values
    if they aren't already present in *args* (the stored height is reduced by ``args.y``, 0 if
    missing). Return a tuple (width, height); an entry stays None if it could not be determined.
    """
    width = args.width
    height = args.height

    if not (width and height):
        filename = get_first_filename(path or args.projections)
        try:
            shape = get_image_shape(filename)
            # Now set the width and height if not specified
            width = width or shape[-1]
            height = height or shape[-2]
        except Exception:
            # Best effort: unreadable files leave the dimensions undetermined
            LOG.info("Couldn't determine image dimensions from '{}'".format(filename))

    if store:
        if not args.width:
            args.width = width
        # Guard against an undetermined height, ``None - args.y`` would raise TypeError
        if not args.height and height is not None:
            args.height = height - getattr(args, 'y', 0)

    return (width, height)
def get_filtering_padding(width):
    """Return how many pixels to add horizontally so that *width* is padded to the next power
    of two of twice the width, which avoids convolution artifacts during filtering."""
    padded_width = next_power_of_two(2 * width)
    return padded_width - width
def setup_padding(pad, width, height, mode, crop=None, pad_width=0, pad_height=0):
    """Configure the *pad* node (and optionally the *crop* node that undoes it).

    The image of size *width* x *height* is padded by *pad_width* x *pad_height*, centered,
    using addressing *mode*. A falsy *pad_width* defaults to ``get_filtering_padding(width)``.
    If *crop* is given it is set up to cut the original size back out. Returns
    ``(pad_width, pad_height)``.
    """
    if not pad_width:
        # Default is horizontal padding only
        pad_width = get_filtering_padding(width)

    padded_width = width + pad_width
    padded_height = height + pad_height
    offset_x = pad_width // 2
    offset_y = pad_height // 2

    props = pad.props
    props.width = padded_width
    props.height = padded_height
    props.x = offset_x
    props.y = offset_y
    props.addressing_mode = mode
    LOG.debug('Padded size: ({}, {})'.format(padded_width, padded_height))
    LOG.debug('Padding mode: {}'.format(mode))

    if crop:
        # crop to original width after filtering
        crop.props.width = width
        crop.props.height = height
        crop.props.x = offset_x
        crop.props.y = offset_y

    return (pad_width, pad_height)
def make_region(n, dtype=int):
    """Return a ``(start, stop, step)`` region of size *n* around 0, all values cast to *dtype*.

    For odd *n* the extra element goes to the positive side, e.g. 5 -> (-2, 3, 1) with int.
    """
    half = n / 2
    start = -dtype(half)
    stop = dtype(half + n % 2)
    return (start, stop, dtype(1))
def get_reconstructed_cube_shape(x_region, y_region, z_region):
    """Get the shape of the reconstructed cube as (slice width, slice height, num slices).

    Each region is a ``(start, stop, step)`` triple; its size is the number of values
    ``numpy.arange`` produces for it.
    """
    import numpy as np

    def count(region):
        start, stop, step = region
        return len(np.arange(start, stop, step))

    num_slices = count(z_region)
    slice_height = count(y_region)
    slice_width = count(x_region)
    return slice_width, slice_height, num_slices
def get_reconstruction_regions(params, store=False, dtype=int):
    """Compute reconstruction regions along all three axes, use *dtype* for as data type for x and y
    regions, z region is always float.

    A region whose stop value is -1 is considered unset and replaced by a region
    centered around 0 (see ``make_region``). If *store* is True the results are written
    back to *params*. Returns ``(x_region, y_region, region)``.
    """
    width, height = determine_shape(params)
    if getattr(params, 'transpose_input', False):
        # In case down the pipeline there is a transpose task
        width, height = height, width

    def region_or_default(current, size, region_dtype):
        if current[1] == -1:
            return make_region(size, dtype=region_dtype)
        return current

    # NOTE(review): the y default is also derived from the width, presumably because a
    # reconstructed slice spans width x width pixels - confirm.
    x_region = region_or_default(params.x_region, width, dtype)
    y_region = region_or_default(params.y_region, width, dtype)
    region = region_or_default(params.region, height, float)

    LOG.info('X region: {}'.format(x_region))
    LOG.info('Y region: {}'.format(y_region))
    LOG.info('Parameter region: {}'.format(region))

    if store:
        params.x_region = x_region
        params.y_region = y_region
        params.region = region

    return x_region, y_region, region
def get_scarray_value(scarray, index):
    """Return ``scarray[index]``; a single-element *scarray* applies to every index."""
    return scarray[0] if len(scarray) == 1 else scarray[index]
class Vector(object):
    """A vector based on axis-angle representation.

    The optional *position* is stored as a float64 NumPy array (None if not given);
    the three angles are stored as passed.
    """

    def __init__(self, x_angle=0, y_angle=0, z_angle=0, position=None):
        import numpy as np
        # ``np.float`` (a deprecated alias of the builtin) was removed in NumPy 1.24;
        # the builtin ``float`` selects the same float64 dtype on all NumPy versions.
        self.position = np.array(position, dtype=float) if position is not None else None
        self.x_angle = x_angle
        self.y_angle = y_angle
        self.z_angle = z_angle

    def __repr__(self):
        return 'Vector(position={}, angles=({}, {}, {}))'.format(self.position,
                                                                 self.x_angle,
                                                                 self.y_angle,
                                                                 self.z_angle)

    def __str__(self):
        return repr(self)
tofu-0.12.0/tofu/vis/ 0000775 0000000 0000000 00000000000 14237137211 0014326 5 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/vis/__init__.py 0000664 0000000 0000000 00000000000 14237137211 0016425 0 ustar 00root root 0000000 0000000 tofu-0.12.0/tofu/vis/qt.py 0000664 0000000 0000000 00000013250 14237137211 0015325 0 ustar 00root root 0000000 0000000 import pyqtgraph as pg
import pyqtgraph.opengl as gl
import logging
import numpy as np
import tifffile
from PyQt4 import QtGui, QtCore
LOG = logging.getLogger(__name__)
def read_tiff(filename):
    """Read the TIFF file *filename* and return its data transposed (``array.T``).

    The file is opened in a ``with`` block so its handle is released after reading
    (the original never closed it). Assumes ``asarray()`` without ``out`` returns an
    in-memory array, which is tifffile's default.
    """
    with tifffile.TiffFile(filename) as tiff:
        array = tiff.asarray()
    return array.T
def remove_extrema(data):
    """Clip *data* in place to its 1st..99th percentile range and return the same array."""
    # Both limits are computed before any value is modified
    high = np.percentile(data, 99)
    low = np.percentile(data, 1)
    above = data > high
    data[above] = high
    below = data < low
    data[below] = low
    return data
def create_volume(data):
    """Build an RGBA ``uint8`` volume from *data* (used for ``GLVolumeItem``).

    The RGB channels receive *data* unchanged, so values are presumably expected to
    lie in 0..255 - confirm. The alpha channel is the squared difference between *data*
    and ``np.roll(data, 1)`` (roll over the flattened array), rescaled to 0..255.
    A constant gradient yields a fully transparent alpha channel instead of the NaNs
    a 0/0 division would produce.
    """
    gradient = (data - np.roll(data, 1))**2
    cmin = gradient.min()
    div = gradient.max() - cmin
    if div:
        gradient = (gradient - cmin) / div * 255
    else:
        gradient = np.zeros_like(gradient)

    volume = np.empty(data.shape + (4, ), dtype=np.ubyte)
    volume[..., 0] = data
    volume[..., 1] = data
    volume[..., 2] = data
    volume[..., 3] = gradient

    return volume
class ImageViewer(QtGui.QWidget):
    """
    Present a sequence of files that can be browsed with a slider.

    To get the currently selected position connect to the *slider* attribute's
    valueChanged signal.
    """

    def __init__(self, filenames, parent=None):
        super(ImageViewer, self).__init__(parent)
        view = pg.ImageView()
        view.getView().setAspectLocked(True)
        self.image_item = view.getImageItem()

        # Connected before any files are loaded so the first image is shown right away
        self.slider = QtGui.QSlider(QtCore.Qt.Horizontal)
        self.slider.valueChanged.connect(self.update_image)

        self.main_layout = QtGui.QVBoxLayout(self)
        for widget in (view, self.slider):
            self.main_layout.addWidget(widget)
        self.setLayout(self.main_layout)

        self.load_files(filenames)

    def load_files(self, filenames):
        """Show *filenames*, one per slider position, starting at the first one."""
        self.filenames = filenames
        last_index = len(self.filenames) - 1
        self.slider.setRange(0, last_index)
        self.slider.setSliderPosition(0)
        self.update_image()

    def update_image(self):
        """Display the file at the current slider position (no-op without files)."""
        if not self.filenames:
            return
        index = self.slider.value()
        self.image_item.setImage(read_tiff(self.filenames[index]))
class ImageWindow(object):
    """
    Stand-alone window to display image sequences.
    """

    global_app = None

    def __init__(self, filenames):
        # Reuse a running QApplication, create one only if none exists yet; the reference
        # is kept on the instance
        app = QtGui.QApplication.instance()
        if not app:
            app = QtGui.QApplication([])
        self.global_app = app

        self.viewer = ImageViewer(filenames)
        self.viewer.show()
class OverlapViewer(QtGui.QWidget):
    """
    Presents two images by subtracting the flipped second from the first.

    To get the current deviation connect to the *slider* attribute's
    valueChanged signal.
    """

    def __init__(self, parent=None, remove_extrema=False):
        # Forward *parent* like the other widgets in this module do; it used to be ignored
        super(OverlapViewer, self).__init__(parent)
        image_view = pg.ImageView()
        image_view.getView().setAspectLocked(True)
        self.image_item = image_view.getImageItem()

        self.slider = QtGui.QSlider(QtCore.Qt.Horizontal)
        self.slider.setRange(0, 0)
        self.slider.valueChanged.connect(self.update_image)

        self.main_layout = QtGui.QVBoxLayout()
        self.main_layout.addWidget(image_view)
        self.main_layout.addWidget(self.slider)
        self.setLayout(self.main_layout)

        self.first, self.second = (None, None)
        # Note: this argument shadows the module-level remove_extrema() only within __init__
        self.remove_extrema = remove_extrema
        self.subtract = True

    def set_images(self, first, second):
        """Set *first* and *second* image.

        Both are transposed and *second* is additionally flipped up-down. If
        ``self.remove_extrema`` is set, copies are clipped to their 1st..99th percentile
        range, so the caller's arrays are never modified.
        """
        self.first, self.second = first.T, np.flipud(second.T)
        if self.remove_extrema:
            # .T and np.flipud return views and remove_extrema() clips in place,
            # so work on copies to leave the caller's arrays untouched
            self.first = remove_extrema(self.first.copy())
            self.second = remove_extrema(self.second.copy())
        if self.first.shape != self.second.shape:
            LOG.warning("Shape {} of {} is different to {} of {}".
                        format(self.first.shape, self.first, self.second.shape, self.second))
        width = self.first.shape[0]
        # QSlider takes ints; ``width / 2`` is a float under Python 3
        half_width = width // 2
        self.slider.setRange(-half_width, int(1.5 * width))
        self.slider.setSliderPosition(half_width)
        self.update_image()

    def set_position(self, position):
        """Move the slider to *position* and refresh the displayed overlap."""
        self.slider.setValue(int(position))
        self.update_image()

    def update_image(self):
        """Update the current subtraction (or sum if ``self.subtract`` is False)."""
        if self.first is None or self.second is None:
            LOG.warning("No images set yet")
        else:
            pos = self.slider.value()
            # Shift the second image along axis 0 according to the slider position
            moved = np.roll(self.second, self.second.shape[0] // 2 - pos, axis=0)
            if self.subtract:
                self.image_item.setImage(moved - self.first)
            else:
                self.image_item.setImage(moved + self.first)
class VolumeViewer(QtGui.QWidget):
    """Widget rendering a stack of TIFF slices as an OpenGL volume."""

    def __init__(self, step=1, density=1, parent=None):
        super(VolumeViewer, self).__init__(parent)
        self.volume_view = gl.GLViewWidget()
        self.main_layout = QtGui.QVBoxLayout()
        self.main_layout.addWidget(self.volume_view)
        self.setLayout(self.main_layout)
        # Subsampling factor along all axes and GLVolumeItem slice density
        self.step = step
        self.density = density

    def load_files(self, filenames):
        """Load *filenames* for display."""
        step = self.step
        selected = filenames[::step]

        first = read_tiff(selected[0])[::step, ::step]
        rows, cols = first.shape
        data = np.empty((rows, cols, len(selected)), dtype=np.float32)
        data[:, :, 0] = first
        for index, filename in enumerate(selected[1:], start=1):
            data[:, :, index] = read_tiff(filename)[::step, ::step]

        volume = create_volume(data)
        dx, dy, dz = volume.shape[:3]
        volume_item = gl.GLVolumeItem(volume, sliceDensity=self.density)
        # Center the volume at the origin, then shrink it for viewing
        volume_item.translate(-dx / 2, -dy / 2, -dz / 2)
        volume_item.scale(0.05, 0.05, 0.05, local=False)
        self.volume_view.addItem(volume_item)
tofu-0.12.0/tox.ini 0000664 0000000 0000000 00000000077 14237137211 0014067 0 ustar 00root root 0000000 0000000 [flake8]
ignore = E402, E721, E722, W503
max-line-length = 100